import logging
from typing import Optional

import boto3

logger = logging.getLogger(__name__)

# Created once at import time and shared by every call in this module.
sm_client = boto3.client("sagemaker")


def _mlflow_stage(description: dict) -> Optional[str]:
    """Return the ``mlflow_current_stage`` customer metadata of a model package.

    Model packages registered without ``CustomerMetadataProperties`` (or without
    that key inside it) yield ``None`` instead of raising ``KeyError``.
    """
    metadata = description.get("CustomerMetadataProperties") or {}
    return metadata.get("mlflow_current_stage")


def get_models_descriptions(model_name: str, stage: Optional[str] = None) -> list:
    """Describe every model package registered in a SageMaker model package group.

    Args:
        model_name: name of the model; underscores are replaced with hyphens
            to build the SageMaker model package group name.
        stage: if given, keep only model packages whose
            ``mlflow_current_stage`` customer metadata equals this value.
            Packages lacking that metadata are excluded.

    Returns:
        A list of ``describe_model_package`` response dictionaries, in the
        order the packages are listed. If listing or describing any package
        fails, the error is logged and an empty list is returned.
    """
    # SageMaker model package group names allow only letters, digits and
    # hyphens, so the MLflow-style name is translated.
    group_name = model_name.replace("_", "-")
    paginator = sm_client.get_paginator("list_model_packages")

    try:
        descriptions = [
            sm_client.describe_model_package(ModelPackageName=summary["ModelPackageArn"])
            for page in paginator.paginate(ModelPackageGroupName=group_name)
            for summary in page["ModelPackageSummaryList"]
        ]
    except Exception:
        # Best-effort lookup: log with traceback and fall back to "no models".
        # `Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit
        # propagate.
        logger.exception("failed to retrieve list of models")
        return []

    if stage is None:
        return descriptions
    return [d for d in descriptions if _mlflow_stage(d) == stage]