# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Manage interactions with the Amazon SageMaker APIs and any other AWS services needed."""
from __future__ import absolute_import, print_function

import json
import logging
import os
import re
import sys
import time
import typing
import warnings
import uuid
from typing import List, Dict, Any, Sequence, Optional

import boto3
import botocore
import botocore.config
from botocore.exceptions import ClientError
import six

from sagemaker.utils import instance_supports_kms
import sagemaker.logs
from sagemaker import vpc_utils, s3_utils
from sagemaker._studio import _append_project_tags
from sagemaker.config import load_sagemaker_config, validate_sagemaker_config
from sagemaker.config import (
    KEY,
    TRAINING_JOB,
    TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
    TRAINING_JOB_ROLE_ARN_PATH,
    TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
    TRAINING_JOB_ENVIRONMENT_PATH,
    TRAINING_JOB_VPC_CONFIG_PATH,
    TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH,
    TRAINING_JOB_RESOURCE_CONFIG_PATH,
    TRAINING_JOB_PROFILE_CONFIG_PATH,
    PROCESSING_JOB_INPUTS_PATH,
    PROCESSING_JOB,
    PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
    PROCESSING_JOB_ENVIRONMENT_PATH,
    PROCESSING_JOB_ROLE_ARN_PATH,
    PROCESSING_JOB_NETWORK_CONFIG_PATH,
    PROCESSING_OUTPUT_CONFIG_PATH,
    PROCESSING_JOB_PROCESSING_RESOURCES_PATH,
    MONITORING_JOB_ENVIRONMENT_PATH,
    MONITORING_JOB_ROLE_ARN_PATH,
    MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH,
    MONITORING_JOB_NETWORK_CONFIG_PATH,
    MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH,
    MONITORING_SCHEDULE,
    MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH,
    AUTO_ML_ROLE_ARN_PATH,
    AUTO_ML_OUTPUT_CONFIG_PATH,
    AUTO_ML_JOB_CONFIG_PATH,
    AUTO_ML_JOB,
    COMPILATION_JOB_ROLE_ARN_PATH,
    COMPILATION_JOB_OUTPUT_CONFIG_PATH,
    COMPILATION_JOB_VPC_CONFIG_PATH,
    COMPILATION_JOB,
    EDGE_PACKAGING_ROLE_ARN_PATH,
    EDGE_PACKAGING_OUTPUT_CONFIG_PATH,
    EDGE_PACKAGING_RESOURCE_KEY_PATH,
    EDGE_PACKAGING_JOB,
    TRANSFORM_JOB,
    TRANSFORM_JOB_ENVIRONMENT_PATH,
    TRANSFORM_JOB_KMS_KEY_ID_PATH,
    TRANSFORM_OUTPUT_KMS_KEY_ID_PATH,
    VOLUME_KMS_KEY_ID,
    TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH,
    MODEL,
    MODEL_CONTAINERS_PATH,
    MODEL_EXECUTION_ROLE_ARN_PATH,
    MODEL_ENABLE_NETWORK_ISOLATION_PATH,
    MODEL_PRIMARY_CONTAINER_PATH,
    MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
    MODEL_VPC_CONFIG_PATH,
    MODEL_PACKAGE_VALIDATION_ROLE_PATH,
    VALIDATION_ROLE,
    VALIDATION_PROFILES,
    MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
    MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
    ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
    KMS_KEY_ID,
    ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
    ENDPOINT_CONFIG,
    ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
    ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
    ENDPOINT,
    SAGEMAKER,
    FEATURE_GROUP,
    TAGS,
    FEATURE_GROUP_ROLE_ARN_PATH,
    FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH,
    FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH,
    SESSION_DEFAULT_S3_BUCKET_PATH,
    SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
)
from sagemaker.config.config_utils import _log_sagemaker_config_merge
from sagemaker.deprecations import deprecated_class
from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig
from sagemaker.user_agent import prepend_user_agent
from sagemaker.utils import (
    name_from_image,
    secondary_training_status_changed,
    secondary_training_status_message,
    sts_regional_endpoint,
    retries,
    resolve_value_from_config,
    get_sagemaker_config_value,
    resolve_class_attribute_from_config,
    resolve_nested_dict_value_from_config,
    update_nested_dictionary_with_values_from_config,
    update_list_of_dicts_with_values_from_config,
)
from sagemaker import exceptions
from sagemaker.session_settings import SessionSettings

LOGGER = logging.getLogger("sagemaker")

NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"

_STATUS_CODE_TABLE = {
    "COMPLETED": "Completed",
    "INPROGRESS": "InProgress",
    "IN_PROGRESS": "InProgress",
    "FAILED": "Failed",
    "STOPPED": "Stopped",
    "STOPPING": "Stopping",
    "STARTING": "Starting",
    "PENDING": "Pending",
}


class LogState(object):
    """States for tailing CloudWatch logs while waiting on a SageMaker job."""

    STARTING = 1
    WAIT_IN_PROGRESS = 2
    TAILING = 3
    JOB_COMPLETE = 4
    COMPLETE = 5


class Session(object):  # pylint: disable=too-many-public-methods
    """Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.

    This class provides convenient methods for manipulating entities and resources that Amazon
    SageMaker uses, such as training jobs, endpoints, and input datasets in S3.

    AWS service calls are delegated to an underlying Boto3 session, which by default is
    initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
    that accesses an S3 bucket location and one is not specified, the ``Session`` creates a
    default bucket based on a naming convention which includes the current AWS account ID.
    """

    def __init__(
        self,
        boto_session=None,
        sagemaker_client=None,
        sagemaker_runtime_client=None,
        sagemaker_featurestore_runtime_client=None,
        default_bucket=None,
        settings=SessionSettings(),
        sagemaker_metrics_client=None,
        sagemaker_config: dict = None,
        default_bucket_prefix: str = None,
    ):
        """Initialize a SageMaker ``Session``.

        Args:
            boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
                calls are delegated to (default: None). If not provided, one is created with
                default AWS configuration chain.
            sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker service
                calls other than ``InvokeEndpoint`` (default: None). Estimators created using this
                ``Session`` use this client. If not provided, one will be created using this
                instance's ``boto_session``.
            sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
                ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
                using this ``Session`` use this client. If not provided, one will be created using
                this instance's ``boto_session``.
            sagemaker_featurestore_runtime_client (boto3.SageMakerFeatureStoreRuntime.Client):
                Client which makes SageMaker FeatureStore record related calls to Amazon SageMaker
                (default: None). If not provided, one will be created using this instance's
                ``boto_session``.
            default_bucket (str): The default Amazon S3 bucket to be used by this session.
                This will be created the next time an Amazon S3 bucket is needed (by calling
                :func:`default_bucket`). If not provided, it will be fetched from the
                sagemaker_config. If not configured there either, a default bucket will be created
                based on the following format: "sagemaker-{region}-{aws-account-id}".
                Example: "sagemaker-my-custom-bucket".
            settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
                parameters to apply to the session.
            sagemaker_metrics_client (boto3.SageMakerMetrics.Client): Client which makes
                SageMaker Metrics related calls to Amazon SageMaker (default: None). If not
                provided, one will be created using this instance's ``boto_session``.
            sagemaker_config (dict): A dictionary containing default values for the SageMaker
                Python SDK. (default: None). The dictionary must adhere to the schema defined at
                `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`. If
                sagemaker_config is not provided and configuration files exist (at the default
                paths for admins and users, or paths set through the environment variables
                SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE), a new
                dictionary will be generated from those configuration files. Alternatively, this
                dictionary can be generated by calling
                :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the
                Session.
            default_bucket_prefix (str): The default prefix to use for S3 Object Keys. (default:
                None). If provided and where applicable, it will be used by the SDK to construct
                default S3 URIs, in the format:
                `s3://{default_bucket}/{default_bucket_prefix}/<rest of object key>`
                This parameter can also be specified via `{sagemaker_config}` instead of here. If
                not provided here or within `{sagemaker_config}`, default S3 URIs will have the
                format: `s3://{default_bucket}/<rest of object key>`
        """
        # sagemaker_config is validated and initialized inside :func:`_initialize`,
        # so if default_bucket is None and the sagemaker_config has a default S3 bucket configured,
        # _default_bucket_name_override will be set again inside :func:`_initialize`.
        self._default_bucket = None
        self._default_bucket_name_override = default_bucket
        # this may also be set again inside :func:`_initialize` if it is None
        self.default_bucket_prefix = default_bucket_prefix

        self.s3_resource = None
        self.s3_client = None
        self.resource_groups_client = None
        self.resource_group_tagging_client = None
        self.config = None
        self.lambda_client = None
        self.settings = settings

        self._initialize(
            boto_session=boto_session,
            sagemaker_client=sagemaker_client,
            sagemaker_runtime_client=sagemaker_runtime_client,
            sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
            sagemaker_metrics_client=sagemaker_metrics_client,
            sagemaker_config=sagemaker_config,
        )

    def _initialize(
        self,
        boto_session,
        sagemaker_client,
        sagemaker_runtime_client,
        sagemaker_featurestore_runtime_client,
        sagemaker_metrics_client,
        sagemaker_config: dict = None,
    ):
        """Initialize this SageMaker Session.

        Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
        Sets the region_name.
        """
        self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()

        self._region_name = self.boto_session.region_name
        if self._region_name is None:
            raise ValueError(
                "Must setup local AWS configuration with a region supported by SageMaker."
            )

        self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker")
        prepend_user_agent(self.sagemaker_client)

        if sagemaker_runtime_client is not None:
            self.sagemaker_runtime_client = sagemaker_runtime_client
        else:
            config = botocore.config.Config(read_timeout=80)
            self.sagemaker_runtime_client = self.boto_session.client(
                "runtime.sagemaker", config=config
            )
        prepend_user_agent(self.sagemaker_runtime_client)

        if sagemaker_featurestore_runtime_client:
            self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client
        else:
            self.sagemaker_featurestore_runtime_client = self.boto_session.client(
                "sagemaker-featurestore-runtime"
            )

        if sagemaker_metrics_client:
            self.sagemaker_metrics_client = sagemaker_metrics_client
        else:
            self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics")
        prepend_user_agent(self.sagemaker_metrics_client)

        self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name)
        self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name)

        self.local_mode = False

        if sagemaker_config:
            validate_sagemaker_config(sagemaker_config)
            self.sagemaker_config = sagemaker_config
        else:
            # self.s3_resource might be None. If it is None, load_sagemaker_config will
            # create a default S3 resource, but only if it needs to fetch from S3
            self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)

        # after sagemaker_config initialization, update self._default_bucket_name_override
        # if needed
        self._default_bucket_name_override = resolve_value_from_config(
            direct_input=self._default_bucket_name_override,
            config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
            sagemaker_session=self,
        )
        # after sagemaker_config initialization, update self.default_bucket_prefix if needed
        self.default_bucket_prefix = resolve_value_from_config(
            direct_input=self.default_bucket_prefix,
            config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
            sagemaker_session=self,
        )

    @property
    def boto_region_name(self):
        """The AWS region name of this session's underlying Boto session."""
        return self._region_name

    def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
        """Upload local file or directory to S3.

        If a single file is specified for upload, the resulting S3 object key is
        ``{key_prefix}/{filename}`` (filename does not include the local path, if any specified).

        If a directory is specified for upload, the API uploads all content, recursively,
        preserving relative structure of subdirectories. The resulting object key names are:
        ``{key_prefix}/{relative_subdirectory_path}/filename``.

        Args:
            path (str): Path (absolute or relative) of local file or directory to upload.
            bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified,
                the default bucket of the ``Session`` is used (if default bucket does not exist,
                the ``Session`` creates it).
            key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
                prefix to create a directory structure for the bucket content that it displays in
                the S3 console.
            extra_args (dict): Optional extra arguments that may be passed to the upload
                operation. Similar to ExtraArgs parameter in S3 upload_file function. Please
                refer to the ExtraArgs parameter documentation here:
                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

        Returns:
            str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
                the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
                If a directory is specified in the path argument, the URI format is
                ``s3://{bucket name}/{key_prefix}``.
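
        Example:
            A minimal usage sketch; the local path, key prefix, and resulting URI shown here are
            hypothetical:

            .. code:: python

                sess = Session()
                uri = sess.upload_data(path="data/train.csv", key_prefix="training-data")
                # e.g. "s3://sagemaker-us-west-2-111122223333/training-data/train.csv"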
        """
        bucket, key_prefix = s3_utils.determine_bucket_and_prefix(
            bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
        )

        # Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
        files = []
        key_suffix = None
        if os.path.isdir(path):
            for dirpath, _, filenames in os.walk(path):
                for name in filenames:
                    local_path = os.path.join(dirpath, name)
                    s3_relative_prefix = (
                        "" if path == dirpath else os.path.relpath(dirpath, start=path) + "/"
                    )
                    s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
                    files.append((local_path, s3_key))
        else:
            _, name = os.path.split(path)
            s3_key = "{}/{}".format(key_prefix, name)
            files.append((path, s3_key))
            key_suffix = name

        if self.s3_resource is None:
            s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_resource

        for local_path, s3_key in files:
            s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)

        s3_uri = "s3://{}/{}".format(bucket, key_prefix)
        # If a specific file was used as input (instead of a directory), we return the full S3 key
        # of the uploaded object. This prevents unintentionally using other files under the same
        # prefix during training.
        if key_suffix:
            s3_uri = "{}/{}".format(s3_uri, key_suffix)
        return s3_uri

    def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
        """Upload a string as a file body.

        Args:
            body (str): String representing the body of the file.
            bucket (str): Name of the S3 Bucket to upload to.
            key (str): S3 object key. This is the s3 path to the file.
            kms_key (str): The KMS key to use for encrypting the file.

        Returns:
            str: The S3 URI of the uploaded file.
                The URI format is: ``s3://{bucket name}/{key}``.
        """
        if self.s3_resource is None:
            s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_resource

        s3_object = s3.Object(bucket_name=bucket, key=key)

        if kms_key is not None:
            s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
        else:
            s3_object.put(Body=body)

        s3_uri = "s3://{}/{}".format(bucket, key)
        return s3_uri

    def download_data(self, path, bucket, key_prefix="", extra_args=None):
        """Download file or directory from S3.

        Args:
            path (str): Local path where the file or directory should be downloaded to.
            bucket (str): Name of the S3 Bucket to download from.
            key_prefix (str): Optional S3 object key name prefix.
            extra_args (dict): Optional extra arguments that may be passed to the download
                operation. Please refer to the ExtraArgs parameter in the boto3 documentation
                here:
                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html

        Returns:
            list[str]: List of local paths of downloaded files.
        """
        # Initialize the S3 client.
        if self.s3_client is None:
            s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_client

        # Initialize the variables used to loop through the contents of the S3 bucket.
        keys = []
        next_token = ""
        base_parameters = {"Bucket": bucket, "Prefix": key_prefix}

        # Loop through the contents of the bucket, 1,000 objects at a time, gathering all keys
        # into the "keys" list.
        while next_token is not None:
            request_parameters = base_parameters.copy()
            if next_token != "":
                request_parameters.update({"ContinuationToken": next_token})
            response = s3.list_objects_v2(**request_parameters)
            contents = response.get("Contents", None)
            if not contents:
                LOGGER.info(
                    "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
                )
                return []
            # For each object, save its key or directory.
            for s3_object in contents:
                key = s3_object.get("Key")
                keys.append(key)
            next_token = response.get("NextContinuationToken")

        # For each object key, create the directory on the local machine if needed, and then
        # download the file.
        downloaded_paths = []
        for key in keys:
            tail_s3_uri_path = os.path.basename(key)
            if not os.path.splitext(key_prefix)[1]:
                tail_s3_uri_path = os.path.relpath(key, key_prefix)
            destination_path = os.path.join(path, tail_s3_uri_path)
            if not os.path.exists(os.path.dirname(destination_path)):
                os.makedirs(os.path.dirname(destination_path))
            s3.download_file(
                Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
            )
            downloaded_paths.append(destination_path)
        return downloaded_paths

    def read_s3_file(self, bucket, key_prefix):
        """Read a single file from S3.

        Args:
            bucket (str): Name of the S3 Bucket to download from.
            key_prefix (str): S3 object key name prefix.

        Returns:
            str: The body of the s3 file as a string.
        """
        if self.s3_client is None:
            s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_client

        # Explicitly passing a None kms_key to boto3 throws a validation error.
        s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)

        return s3_object["Body"].read().decode("utf-8")

    def list_s3_files(self, bucket, key_prefix):
        """Lists the S3 files given an S3 bucket and key.

        Args:
            bucket (str): Name of the S3 Bucket to download from.
            key_prefix (str): S3 object key name prefix.

        Returns:
            [str]: The list of files at the S3 path.
        """
        if self.s3_resource is None:
            s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_resource

        s3_bucket = s3.Bucket(name=bucket)
        s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
        return [s3_object.key for s3_object in s3_objects]

    def default_bucket(self):
        """Return the name of the default bucket to use in relevant Amazon SageMaker interactions.

        This function will create the s3 bucket if it does not exist.

        Returns:
            str: The name of the default bucket. If the name was not explicitly specified through
                the Session or sagemaker_config, the bucket will take the form:
                ``sagemaker-{region}-{AWS account ID}``.
        """
        if self._default_bucket:
            return self._default_bucket

        region = self.boto_session.region_name

        default_bucket = self._default_bucket_name_override
        if not default_bucket:
            default_bucket = generate_default_sagemaker_bucket_name(self.boto_session)

        self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region)

        self._default_bucket = default_bucket

        return self._default_bucket

    def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
        """Creates an S3 Bucket if it does not exist.

        Also swallows a few common exceptions that indicate that the bucket already exists or
        that it is being created.

        Args:
            bucket_name (str): Name of the S3 bucket to be created.
            region (str): The region in which to create the bucket.

        Raises:
            botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
                creation. If the exception is due to the bucket already existing or already being
                created, no exception is raised.
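
        Example:
            This method is not usually called directly; a sketch of how :func:`default_bucket`
            exercises it (the region and account in the bucket name are hypothetical):

            .. code:: python

                sess = Session()
                bucket = sess.default_bucket()
                # Creates the bucket on first use if needed,
                # e.g. "sagemaker-us-west-2-111122223333"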
""" if self.s3_resource is None: s3 = self.boto_session.resource("s3", region_name=region) else: s3 = self.s3_resource bucket = s3.Bucket(name=bucket_name) if bucket.creation_date is None: try: # trying head bucket call s3.meta.client.head_bucket(Bucket=bucket.name) except ClientError as e: # bucket does not exist or forbidden to access error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] if error_code == "404" and message == "Not Found": # bucket does not exist, create one try: if region == "us-east-1": # 'us-east-1' cannot be specified because it is the default region: # https://github.com/boto/boto3/issues/125 s3.create_bucket(Bucket=bucket_name) else: s3.create_bucket( Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}, ) LOGGER.info("Created S3 bucket: %s", bucket_name) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] if ( error_code == "OperationAborted" and "conflicting conditional operation" in message ): # If this bucket is already being concurrently created, # we don't need to create it again. pass else: raise elif error_code == "403" and message == "Forbidden": LOGGER.error( "Bucket %s exists, but access is forbidden. Please try again after " "adding appropriate access.", bucket.name, ) raise else: raise def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str): """Appends tags specified in the sagemaker_config to the given list of tags. To minimize the chance of duplicate tags being applied, this is intended to be used immediately before calls to sagemaker_client, rather than during initialization of classes like EstimatorBase. Args: tags: The list of tags to append to. config_path_to_tags: The path to look up tags in the config. Returns: A list of tags. """ config_tags = get_sagemaker_config_value(self, config_path_to_tags) if config_tags is None or len(config_tags) == 0: return tags all_tags = tags or [] for config_tag in config_tags: config_tag_key = config_tag[KEY] if not any(tag.get("Key", None) == config_tag_key for tag in all_tags): # This check prevents new tags with duplicate keys from being added # (to prevent API failure and/or overwriting of tags). If there is a conflict, # the user-provided tag should take precedence over the config-provided tag. # Note: this does not check user-provided tags for conflicts with other # user-provided tags. all_tags.append(config_tag) _log_sagemaker_config_merge( source_value=tags, config_value=config_tags, merged_source_and_config_value=all_tags, config_key_path=config_path_to_tags, ) return all_tags def train( # noqa: C901 self, input_mode, input_config, role=None, job_name=None, output_config=None, resource_config=None, vpc_config=None, hyperparameters=None, stop_condition=None, tags=None, metric_definitions=None, enable_network_isolation=None, image_uri=None, training_image_config=None, container_entry_point=None, container_arguments=None, algorithm_arn=None, encrypt_inter_container_traffic=None, use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, experiment_config=None, debugger_rule_configs=None, debugger_hook_config=None, tensorboard_output_config=None, enable_sagemaker_metrics=None, profiler_rule_configs=None, profiler_config=None, environment: Optional[Dict[str, str]] = None, retry_strategy=None, ): """Create an Amazon SageMaker training job. Args: input_mode (str): The input mode that the algorithm supports. 
                Valid modes:

                * 'File' - Amazon SageMaker copies the training dataset from the S3 location to
                  a directory in the Docker container.
                * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
                  Unix-named pipe.
                * 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
                  downloading the entire dataset before training begins.

            input_config (list): A list of Channel objects. Each channel is a named input source.
                Please refer to the format details described at:
                https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
                jobs and APIs that create Amazon SageMaker endpoints use this role to access
                training data and model artifacts. You must grant sufficient permissions to this
                role.
            job_name (str): Name of the training job being created.
            output_config (dict): The S3 URI where you want to store the training results and
                optional KMS key ID.
            resource_config (dict): Contains values for ResourceConfig:

                * instance_count (int): Number of EC2 instances to use for training.
                  The key in resource_config is 'InstanceCount'.
                * instance_type (str): Type of EC2 instance to use for training, for example,
                  'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.

            vpc_config (dict): Contains values for VpcConfig:

                * subnets (list[str]): List of subnet ids.
                  The key in vpc_config is 'Subnets'.
                * security_group_ids (list[str]): List of security group ids.
                  The key in vpc_config is 'SecurityGroupIds'.

            hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
                made accessible as a dict[str, str] to the training code on SageMaker. For
                convenience, this accepts other types for keys and values, but ``str()`` will be
                called to convert them before training.
            stop_condition (dict): Defines when training shall finish. Contains entries that can
                be understood by the service like ``MaxRuntimeInSeconds``.
            tags (list[dict]): List of tags for labeling a training job. For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
                used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
                the name of the metric, and 'Regex' for the regular expression used to extract
                the metric from the logs.
            enable_network_isolation (bool): Whether to request for the training job to run with
                network isolation or not.
            image_uri (str): Docker image containing training code.
            training_image_config(dict): Training image configuration.
                Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
                'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
                For example,

                .. code:: python

                    training_image_config = {
                        "TrainingRepositoryAccessMode": "Vpc",
                        "TrainingRepositoryAuthConfig": {
                            "TrainingRepositoryCredentialsProviderArn":
                                "arn:aws:lambda:us-west-2:1234567890:function:test"
                        },
                    }

                If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed
                through a private Docker registry in customer Vpc. If it's set to Platform or
                None, the training image is accessed through ECR. If
                TrainingRepositoryCredentialsProviderArn is provided, the credentials to
                authenticate to the private Docker registry will be retrieved from this AWS
                Lambda function. (default: ``None``). When it's set to None, SageMaker will not
                do authentication before pulling the image in the private Docker registry.
            container_entry_point (List[str]): Optional. The entrypoint script for a Docker
                container used to run a training job. This script takes precedence over the
                default train processing instructions.
            container_arguments (List[str]): Optional. The arguments for a container used to run
                a training job.
            algorithm_arn (str): Algorithm Arn from Marketplace.
            encrypt_inter_container_traffic (bool): Specifies whether traffic between training
                containers is encrypted for the training job (default: ``False``).
            use_spot_instances (bool): whether to use spot instances for training.
            checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints that the
                algorithm persists (if any) during training. (default: ``None``).
            checkpoint_local_path (str): The local path that the algorithm writes its checkpoints
                to. SageMaker will persist all files under this path to `checkpoint_s3_uri`
                continually during training. On job startup the reverse happens - data from the
                s3 location is downloaded to this path before the algorithm is started. If the
                path is unset then SageMaker assumes the checkpoints will be provided under
                `/opt/ml/checkpoints/`. (default: ``None``).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain four keys:
                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
                The behavior of setting these keys is as follows:

                * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If `TrialName` is supplied and the Trial already exists, the job's Trial
                  Component will be associated with the Trial.
                * If both `ExperimentName` and `TrialName` are not supplied, the Trial Component
                  will be unassociated.
                * `TrialComponentDisplayName` is used for display in Studio.
                * `RunName` is used to record an experiment run.

            enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more
                information, see:
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
                (default: ``None``).
            profiler_rule_configs (list[dict]): A list of profiler rule configurations.
            profiler_config (dict): Configuration for how profiling information is emitted with
                SageMaker Profiler. (default: ``None``).
            environment (dict[str, str]): Environment variables to be set for use during training
                job (default: ``None``).
            retry_strategy (dict): Defines RetryStrategy for InternalServerFailures.

                * max_retry_attempts (int): Number of times a job should be retried.
                  The key in RetryStrategy is 'MaxRetryAttempts'.

        Returns:
            str: ARN of the training job, if it is created.
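
        Example:
            A minimal sketch of a ``train`` call; the role ARN, image URI, and S3 URIs below are
            hypothetical placeholders:

            .. code:: python

                sess = Session()
                sess.train(
                    input_mode="File",
                    input_config=[
                        {
                            "ChannelName": "training",
                            "DataSource": {
                                "S3DataSource": {
                                    "S3DataType": "S3Prefix",
                                    "S3Uri": "s3://my-bucket/train/",
                                    "S3DataDistributionType": "FullyReplicated",
                                }
                            },
                        }
                    ],
                    role="arn:aws:iam::111122223333:role/SageMakerRole",
                    job_name="my-training-job",
                    output_config={"S3OutputPath": "s3://my-bucket/output/"},
                    resource_config={
                        "InstanceCount": 1,
                        "InstanceType": "ml.m5.xlarge",
                        "VolumeSizeInGB": 30,
                    },
                    stop_condition={"MaxRuntimeInSeconds": 3600},
                    image_uri="111122223333.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
                )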
""" tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS) ) _encrypt_inter_container_traffic = resolve_value_from_config( direct_input=encrypt_inter_container_traffic, config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, default_value=False, sagemaker_session=self, ) role = resolve_value_from_config(role, TRAINING_JOB_ROLE_ARN_PATH, sagemaker_session=self) enable_network_isolation = resolve_value_from_config( direct_input=enable_network_isolation, config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, default_value=False, sagemaker_session=self, ) inferred_vpc_config = update_nested_dictionary_with_values_from_config( vpc_config, TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self ) inferred_output_config = update_nested_dictionary_with_values_from_config( output_config, TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, sagemaker_session=self ) customer_supplied_kms_key = "VolumeKmsKeyId" in resource_config inferred_resource_config = update_nested_dictionary_with_values_from_config( resource_config, TRAINING_JOB_RESOURCE_CONFIG_PATH, sagemaker_session=self ) inferred_profiler_config = update_nested_dictionary_with_values_from_config( profiler_config, TRAINING_JOB_PROFILE_CONFIG_PATH, sagemaker_session=self ) if ( not customer_supplied_kms_key and "InstanceType" in inferred_resource_config and not instance_supports_kms(inferred_resource_config["InstanceType"]) and "VolumeKmsKeyId" in inferred_resource_config ): del inferred_resource_config["VolumeKmsKeyId"] environment = resolve_value_from_config( direct_input=environment, config_path=TRAINING_JOB_ENVIRONMENT_PATH, default_value=None, sagemaker_session=self, ) train_request = self._get_train_request( input_mode=input_mode, input_config=input_config, role=role, job_name=job_name, output_config=inferred_output_config, resource_config=inferred_resource_config, vpc_config=inferred_vpc_config, hyperparameters=hyperparameters, stop_condition=stop_condition, tags=tags, metric_definitions=metric_definitions, enable_network_isolation=enable_network_isolation, image_uri=image_uri, training_image_config=training_image_config, container_entry_point=container_entry_point, container_arguments=container_arguments, algorithm_arn=algorithm_arn, encrypt_inter_container_traffic=_encrypt_inter_container_traffic, use_spot_instances=use_spot_instances, checkpoint_s3_uri=checkpoint_s3_uri, checkpoint_local_path=checkpoint_local_path, experiment_config=experiment_config, debugger_rule_configs=debugger_rule_configs, debugger_hook_config=debugger_hook_config, tensorboard_output_config=tensorboard_output_config, enable_sagemaker_metrics=enable_sagemaker_metrics, profiler_rule_configs=profiler_rule_configs, profiler_config=inferred_profiler_config, environment=environment, retry_strategy=retry_strategy, ) def submit(request): LOGGER.info("Creating training-job with name: %s", job_name) LOGGER.debug("train request: %s", json.dumps(request, indent=4)) self.sagemaker_client.create_training_job(**request) self._intercept_create_request(train_request, submit, self.train.__name__) def _get_train_request( # noqa: C901 self, input_mode, input_config, role, job_name, output_config, resource_config, vpc_config, hyperparameters, stop_condition, tags, metric_definitions, enable_network_isolation=False, image_uri=None, training_image_config=None, container_entry_point=None, container_arguments=None, algorithm_arn=None, encrypt_inter_container_traffic=False, use_spot_instances=False, checkpoint_s3_uri=None, 
        checkpoint_local_path=None,
        experiment_config=None,
        debugger_rule_configs=None,
        debugger_hook_config=None,
        tensorboard_output_config=None,
        enable_sagemaker_metrics=None,
        profiler_rule_configs=None,
        profiler_config=None,
        environment=None,
        retry_strategy=None,
    ):
        """Constructs a request compatible for creating an Amazon SageMaker training job.

        Args:
            input_mode (str): The input mode that the algorithm supports.
                Valid modes:

                * 'File' - Amazon SageMaker copies the training dataset from the S3 location to
                  a directory in the Docker container.
                * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
                  Unix-named pipe.
                * 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
                  downloading the entire dataset before training begins.

            input_config (list): A list of Channel objects. Each channel is a named input source.
                Please refer to the format details described at:
                https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
                jobs and APIs that create Amazon SageMaker endpoints use this role to access
                training data and model artifacts. You must grant sufficient permissions to this
                role.
            job_name (str): Name of the training job being created.
            output_config (dict): The S3 URI where you want to store the training results and
                optional KMS key ID.
            resource_config (dict): Contains values for ResourceConfig:

                * instance_count (int): Number of EC2 instances to use for training.
                  The key in resource_config is 'InstanceCount'.
                * instance_type (str): Type of EC2 instance to use for training, for example,
                  'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.

            vpc_config (dict): Contains values for VpcConfig:

                * subnets (list[str]): List of subnet ids.
                  The key in vpc_config is 'Subnets'.
                * security_group_ids (list[str]): List of security group ids.
                  The key in vpc_config is 'SecurityGroupIds'.

            hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
                made accessible as a dict[str, str] to the training code on SageMaker. For
                convenience, this accepts other types for keys and values, but ``str()`` will be
                called to convert them before training.
            stop_condition (dict): Defines when training shall finish. Contains entries that can
                be understood by the service like ``MaxRuntimeInSeconds``.
            tags (list[dict]): List of tags for labeling a training job. For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
                used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
                the name of the metric, and 'Regex' for the regular expression used to extract
                the metric from the logs.
            enable_network_isolation (bool): Whether to request for the training job to run with
                network isolation or not.
            image_uri (str): Docker image containing training code.
            training_image_config(dict): Training image configuration.
                Optionally, the dict can contain 'TrainingRepositoryAccessMode' and
                'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig').
                For example,

                .. code:: python

                    training_image_config = {
                        "TrainingRepositoryAccessMode": "Vpc",
                        "TrainingRepositoryAuthConfig": {
                            "TrainingRepositoryCredentialsProviderArn":
                                "arn:aws:lambda:us-west-2:1234567890:function:test"
                        },
                    }

                If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed
                through a private Docker registry in customer Vpc.
                If it's set to Platform or None, the training image is accessed through ECR. If
                TrainingRepositoryCredentialsProviderArn is provided, the credentials to
                authenticate to the private Docker registry will be retrieved from this AWS
                Lambda function. (default: ``None``). When it's set to None, SageMaker will not
                do authentication before pulling the image in the private Docker registry.
            container_entry_point (List[str]): Optional. The entrypoint script for a Docker
                container used to run a training job. This script takes precedence over the
                default train processing instructions.
            container_arguments (List[str]): Optional. The arguments for a container used to run
                a training job.
            algorithm_arn (str): Algorithm Arn from Marketplace.
            encrypt_inter_container_traffic (bool): Specifies whether traffic between training
                containers is encrypted for the training job (default: ``False``).
            use_spot_instances (bool): whether to use spot instances for training.
            checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints that the
                algorithm persists (if any) during training. (default: ``None``).
            checkpoint_local_path (str): The local path that the algorithm writes its checkpoints
                to. SageMaker will persist all files under this path to `checkpoint_s3_uri`
                continually during training. On job startup the reverse happens - data from the
                s3 location is downloaded to this path before the algorithm is started. If the
                path is unset then SageMaker assumes the checkpoints will be provided under
                `/opt/ml/checkpoints/`. (default: ``None``).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain four keys:
                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
                The behavior of setting these keys is as follows:

                * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If `TrialName` is supplied and the Trial already exists, the job's Trial
                  Component will be associated with the Trial.
                * If both `ExperimentName` and `TrialName` are not supplied, the Trial Component
                  will be unassociated.
                * `TrialComponentDisplayName` is used for display in Studio.
                * `RunName` is used to record an experiment run.

            enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more
                information, see:
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
                (default: ``None``).
            profiler_rule_configs (list[dict]): A list of profiler rule configurations.
            profiler_config (dict): Configuration for how profiling information is emitted with
                SageMaker Profiler. (default: ``None``).
            environment (dict[str, str]): Environment variables to be set for use during training
                job (default: ``None``).
            retry_strategy (dict): Defines RetryStrategy for InternalServerFailures.

                * max_retry_attempts (int): Number of times a job should be retried.
                  The key in RetryStrategy is 'MaxRetryAttempts'.

        Returns:
            Dict: a training request dict
        """
        train_request = {
            "AlgorithmSpecification": {"TrainingInputMode": input_mode},
            "OutputDataConfig": output_config,
            "TrainingJobName": job_name,
            "StoppingCondition": stop_condition,
            "ResourceConfig": resource_config,
            "RoleArn": role,
        }

        if image_uri and algorithm_arn:
            raise ValueError(
                "image_uri and algorithm_arn are mutually exclusive. "
"Both were provided: image_uri: %s algorithm_arn: %s" % (image_uri, algorithm_arn) ) if image_uri is None and algorithm_arn is None: raise ValueError("either image_uri or algorithm_arn is required. None was provided.") if image_uri is not None: train_request["AlgorithmSpecification"]["TrainingImage"] = image_uri if training_image_config is not None: train_request["AlgorithmSpecification"]["TrainingImageConfig"] = training_image_config if container_entry_point is not None: train_request["AlgorithmSpecification"]["ContainerEntrypoint"] = container_entry_point if container_arguments is not None: train_request["AlgorithmSpecification"]["ContainerArguments"] = container_arguments if algorithm_arn is not None: train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn if input_config is not None: train_request["InputDataConfig"] = input_config if metric_definitions is not None: train_request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions if enable_sagemaker_metrics is not None: train_request["AlgorithmSpecification"][ "EnableSageMakerMetricsTimeSeries" ] = enable_sagemaker_metrics if hyperparameters and len(hyperparameters) > 0: train_request["HyperParameters"] = hyperparameters if environment is not None: train_request["Environment"] = environment if tags is not None: train_request["Tags"] = tags if vpc_config is not None: train_request["VpcConfig"] = vpc_config if experiment_config and len(experiment_config) > 0: train_request["ExperimentConfig"] = experiment_config if enable_network_isolation: train_request["EnableNetworkIsolation"] = enable_network_isolation if encrypt_inter_container_traffic: train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic if use_spot_instances: # estimator.use_spot_instances may be a Pipeline ParameterBoolean object # which is parsed during the Pipeline execution runtime train_request["EnableManagedSpotTraining"] = use_spot_instances if checkpoint_s3_uri: checkpoint_config = {"S3Uri": checkpoint_s3_uri} if checkpoint_local_path: checkpoint_config["LocalPath"] = checkpoint_local_path train_request["CheckpointConfig"] = checkpoint_config if debugger_rule_configs is not None: train_request["DebugRuleConfigurations"] = debugger_rule_configs if debugger_hook_config is not None: train_request["DebugHookConfig"] = debugger_hook_config if tensorboard_output_config is not None: train_request["TensorBoardOutputConfig"] = tensorboard_output_config if profiler_rule_configs is not None: train_request["ProfilerRuleConfigurations"] = profiler_rule_configs if profiler_config is not None: train_request["ProfilerConfig"] = profiler_config if retry_strategy is not None: train_request["RetryStrategy"] = retry_strategy return train_request def update_training_job( self, job_name, profiler_rule_configs=None, profiler_config=None, resource_config=None, ): """Calls the UpdateTrainingJob API for the given job name and returns the response. Args: job_name (str): Name of the training job being updated. profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``). profiler_config(dict): Configuration for how profiling information is emitted with SageMaker Profiler. (default: ``None``). resource_config (dict): Configuration of the resources for the training job. You can update the keep-alive period if the warm pool status is `Available`. No other fields can be updated. (default: ``None``). 
""" # No injections from sagemaker_config because the UpdateTrainingJob API's resource_config # object accepts fewer parameters than the CreateTrainingJob API, and none that the # sagemaker_config currently supports inferred_profiler_config = update_nested_dictionary_with_values_from_config( profiler_config, TRAINING_JOB_PROFILE_CONFIG_PATH, sagemaker_session=self ) update_training_job_request = self._get_update_training_job_request( job_name=job_name, profiler_rule_configs=profiler_rule_configs, profiler_config=inferred_profiler_config, resource_config=resource_config, ) LOGGER.info("Updating training job with name %s", job_name) LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4)) self.sagemaker_client.update_training_job(**update_training_job_request) def _get_update_training_job_request( self, job_name, profiler_rule_configs=None, profiler_config=None, resource_config=None, ): """Constructs a request compatible for updating an Amazon SageMaker training job. Args: job_name (str): Name of the training job being updated. profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``). profiler_config(dict): Configuration for how profiling information is emitted with SageMaker Profiler. (default: ``None``). resource_config (dict): Configuration of the resources for the training job. You can update the keep-alive period if the warm pool status is `Available`. No other fields can be updated. (default: ``None``). Returns: Dict: an update training request dict """ update_training_job_request = { "TrainingJobName": job_name, } if profiler_rule_configs is not None: update_training_job_request["ProfilerRuleConfigurations"] = profiler_rule_configs if profiler_config is not None: update_training_job_request["ProfilerConfig"] = profiler_config if resource_config is not None: update_training_job_request["ResourceConfig"] = resource_config return update_training_job_request def process( self, inputs, output_config, job_name, resources, stopping_condition, app_specification, environment: Optional[Dict[str, str]] = None, network_config=None, role_arn=None, tags=None, experiment_config=None, ): """Create an Amazon SageMaker processing job. Args: inputs ([dict]): List of up to 10 ProcessingInput dictionaries. output_config (dict): A config dictionary, which contains a list of up to 10 ProcessingOutput dictionaries, as well as an optional KMS key ID. job_name (str): The name of the processing job. The name must be unique within an AWS Region in an AWS account. Names should have minimum length of 1 and maximum length of 63 characters. resources (dict): Encapsulates the resources, including ML instances and storage, to use for the processing job. stopping_condition (dict[str,int]): Specifies a limit to how long the processing job can run, in seconds. app_specification (dict[str,str]): Configures the processing job to run the given image. Details are in the processing container specification. environment (dict): Environment variables to start the processing container with. network_config (dict): Specifies networking options, such as network traffic encryption between processing containers, whether to allow inbound and outbound network calls to and from processing containers, and VPC subnets and security groups to use for VPC-enabled processing jobs. role_arn (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. tags ([dict[str,str]]): A list of dictionaries containing key-value pairs. 
        """
        tags = _append_project_tags(tags)
        tags = self._append_sagemaker_config_tags(
            tags, "{}.{}.{}".format(SAGEMAKER, PROCESSING_JOB, TAGS)
        )
        network_config = resolve_nested_dict_value_from_config(
            network_config,
            ["EnableInterContainerTrafficEncryption"],
            PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
            sagemaker_session=self,
        )
        # Processing Input can either have AthenaDatasetDefinition or RedshiftDatasetDefinition
        # or neither, but not both
        union_key_paths_for_dataset_definition = [
            [
                "DatasetDefinition",
                "S3Input",
            ],
            [
                "DatasetDefinition.AthenaDatasetDefinition",
                "DatasetDefinition.RedshiftDatasetDefinition",
            ],
        ]
        update_list_of_dicts_with_values_from_config(
            inputs,
            PROCESSING_JOB_INPUTS_PATH,
            union_key_paths=union_key_paths_for_dataset_definition,
            sagemaker_session=self,
        )
        role_arn = resolve_value_from_config(
            role_arn, PROCESSING_JOB_ROLE_ARN_PATH, sagemaker_session=self
        )
        inferred_network_config_from_config = update_nested_dictionary_with_values_from_config(
            network_config, PROCESSING_JOB_NETWORK_CONFIG_PATH, sagemaker_session=self
        )
        inferred_output_config = update_nested_dictionary_with_values_from_config(
            output_config, PROCESSING_OUTPUT_CONFIG_PATH, sagemaker_session=self
        )
        inferred_resources_config = update_nested_dictionary_with_values_from_config(
            resources, PROCESSING_JOB_PROCESSING_RESOURCES_PATH, sagemaker_session=self
        )
        environment = resolve_value_from_config(
            direct_input=environment,
            config_path=PROCESSING_JOB_ENVIRONMENT_PATH,
            default_value=None,
            sagemaker_session=self,
        )
        process_request = self._get_process_request(
            inputs=inputs,
            output_config=inferred_output_config,
            job_name=job_name,
            resources=inferred_resources_config,
            stopping_condition=stopping_condition,
            app_specification=app_specification,
            environment=environment,
            network_config=inferred_network_config_from_config,
            role_arn=role_arn,
            tags=tags,
            experiment_config=experiment_config,
        )

        def submit(request):
            LOGGER.info("Creating processing-job with name %s", job_name)
            LOGGER.debug("process request: %s", json.dumps(request, indent=4))
            self.sagemaker_client.create_processing_job(**request)

        self._intercept_create_request(process_request, submit, self.process.__name__)

    def _get_process_request(
        self,
        inputs,
        output_config,
        job_name,
        resources,
        stopping_condition,
        app_specification,
        environment,
        network_config,
        role_arn,
        tags,
        experiment_config=None,
    ):
        """Constructs a request compatible for an Amazon SageMaker processing job.

        Args:
            inputs ([dict]): List of up to 10 ProcessingInput dictionaries.
            output_config (dict): A config dictionary, which contains a list of up to 10
                ProcessingOutput dictionaries, as well as an optional KMS key ID.
            job_name (str): The name of the processing job. The name must be unique within an
                AWS Region in an AWS account. Names should have minimum length of 1 and maximum
                length of 63 characters.
            resources (dict): Encapsulates the resources, including ML instances and storage, to
                use for the processing job.
            stopping_condition (dict[str,int]): Specifies a limit to how long the processing job
                can run, in seconds.
            app_specification (dict[str,str]): Configures the processing job to run the given
                image. Details are in the processing container specification.
            environment (dict): Environment variables to start the processing container with.
            network_config (dict): Specifies networking options, such as network traffic
                encryption between processing containers, whether to allow inbound and outbound
                network calls to and from processing containers, and VPC subnets and security
                groups to use for VPC-enabled processing jobs.
            role_arn (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker
                can assume to perform tasks on your behalf.
            tags ([dict[str,str]]): A list of dictionaries containing key-value pairs.
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
                The behavior of setting these keys is as follows:

                * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If `TrialName` is supplied and the Trial already exists, the job's Trial
                  Component will be associated with the Trial.
                * If both `ExperimentName` and `TrialName` are not supplied, the Trial Component
                  will be unassociated.
                * `TrialComponentDisplayName` is used for display in Studio.

        Returns:
            Dict: a processing job request dict
        """
        process_request = {
            "ProcessingJobName": job_name,
            "ProcessingResources": resources,
            "AppSpecification": app_specification,
            "RoleArn": role_arn,
        }

        if inputs:
            process_request["ProcessingInputs"] = inputs

        if output_config["Outputs"]:
            process_request["ProcessingOutputConfig"] = output_config

        if environment is not None:
            process_request["Environment"] = environment

        if network_config is not None:
            process_request["NetworkConfig"] = network_config

        if stopping_condition is not None:
            process_request["StoppingCondition"] = stopping_condition

        if tags is not None:
            process_request["Tags"] = tags

        if experiment_config:
            process_request["ExperimentConfig"] = experiment_config

        return process_request

    def create_monitoring_schedule(
        self,
        monitoring_schedule_name,
        schedule_expression,
        statistics_s3_uri,
        constraints_s3_uri,
        monitoring_inputs,
        monitoring_output_config,
        instance_count,
        instance_type,
        volume_size_in_gb,
        volume_kms_key=None,
        image_uri=None,
        entrypoint=None,
        arguments=None,
        record_preprocessor_source_uri=None,
        post_analytics_processor_source_uri=None,
        max_runtime_in_seconds=None,
        environment=None,
        network_config=None,
        role_arn=None,
        tags=None,
    ):
        """Create an Amazon SageMaker monitoring schedule.

        Args:
            monitoring_schedule_name (str): The name of the monitoring schedule. The name must be
                unique within an AWS Region in an AWS account. Names should have a minimum length
                of 1 and a maximum length of 63 characters.
            schedule_expression (str): The cron expression that dictates the monitoring execution
                schedule.
            statistics_s3_uri (str): The S3 uri of the statistics file to use.
            constraints_s3_uri (str): The S3 uri of the constraints file to use.
            monitoring_inputs ([dict]): List of MonitoringInput dictionaries.
            monitoring_output_config (dict): A config dictionary, which contains a list of
                MonitoringOutput dictionaries, as well as an optional KMS key ID.
            instance_count (int): The number of instances to run.
            instance_type (str): The type of instance to run.
            volume_size_in_gb (int): Size of the volume in GB.
            volume_kms_key (str): KMS key to use when encrypting the volume.
            image_uri (str): The image uri to use for monitoring executions.
            entrypoint ([str]): The entrypoint to the monitoring execution image.
            arguments ([str]): The arguments to pass to the monitoring execution image.
            record_preprocessor_source_uri (str or None): The S3 uri that points to the script
                that pre-processes the dataset (only applicable to first-party images).
            post_analytics_processor_source_uri (str or None): The S3 uri that points to the
                script that post-processes the dataset (only applicable to first-party images).
            max_runtime_in_seconds (int): Specifies a limit to how long the processing job can
                run, in seconds.
            environment (dict): Environment variables to start the monitoring execution
                container with.
            network_config (dict): Specifies networking options, such as network traffic
                encryption between processing containers, whether to allow inbound and outbound
                network calls to and from processing containers, and VPC subnets and security
                groups to use for VPC-enabled processing jobs.
            role_arn (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker
                can assume to perform tasks on your behalf.
            tags ([dict[str,str]]): A list of dictionaries containing key-value pairs.
        """
        role_arn = resolve_value_from_config(
            role_arn, MONITORING_JOB_ROLE_ARN_PATH, sagemaker_session=self
        )
        volume_kms_key = resolve_value_from_config(
            volume_kms_key, MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH, sagemaker_session=self
        )
        inferred_network_config_from_config = update_nested_dictionary_with_values_from_config(
            network_config, MONITORING_JOB_NETWORK_CONFIG_PATH, sagemaker_session=self
        )
        environment = resolve_value_from_config(
            direct_input=environment,
            config_path=MONITORING_JOB_ENVIRONMENT_PATH,
            default_value=None,
            sagemaker_session=self,
        )
        monitoring_schedule_request = {
            "MonitoringScheduleName": monitoring_schedule_name,
            "MonitoringScheduleConfig": {
                "MonitoringJobDefinition": {
                    "Environment": environment,
                    "MonitoringInputs": monitoring_inputs,
                    "MonitoringResources": {
                        "ClusterConfig": {
                            "InstanceCount": instance_count,
                            "InstanceType": instance_type,
                            "VolumeSizeInGB": volume_size_in_gb,
                        }
                    },
                    "MonitoringAppSpecification": {"ImageUri": image_uri},
                    "RoleArn": role_arn,
                }
            },
        }

        if schedule_expression is not None:
            monitoring_schedule_request["MonitoringScheduleConfig"]["ScheduleConfig"] = {
                "ScheduleExpression": schedule_expression
            }

        if monitoring_output_config is not None:
            kms_key_from_config = resolve_value_from_config(
                config_path=MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, sagemaker_session=self
            )
            if KMS_KEY_ID not in monitoring_output_config and kms_key_from_config:
                monitoring_output_config[KMS_KEY_ID] = kms_key_from_config
            monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
                "MonitoringOutputConfig"
            ] = monitoring_output_config

        if statistics_s3_uri is not None or constraints_s3_uri is not None:
            monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
                "BaselineConfig"
            ] = {}

        if statistics_s3_uri is not None:
            monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
                "BaselineConfig"
            ]["StatisticsResource"] = {"S3Uri": statistics_s3_uri}

        if constraints_s3_uri is not None:
            monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
                "BaselineConfig"
            ]["ConstraintsResource"] = {"S3Uri": constraints_s3_uri}
constraints_s3_uri} if record_preprocessor_source_uri is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["RecordPreprocessorSourceUri"] = record_preprocessor_source_uri if post_analytics_processor_source_uri is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["PostAnalyticsProcessorSourceUri"] = post_analytics_processor_source_uri if entrypoint is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["ContainerEntrypoint"] = entrypoint if arguments is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["ContainerArguments"] = arguments if volume_kms_key is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringResources" ]["ClusterConfig"]["VolumeKmsKeyId"] = volume_kms_key if max_runtime_in_seconds is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "StoppingCondition" ] = {"MaxRuntimeInSeconds": max_runtime_in_seconds} if environment is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "Environment" ] = environment if inferred_network_config_from_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "NetworkConfig" ] = inferred_network_config_from_config tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, MONITORING_SCHEDULE, TAGS) ) if tags is not None: monitoring_schedule_request["Tags"] = tags LOGGER.info("Creating monitoring schedule name %s.", monitoring_schedule_name) LOGGER.debug( "monitoring_schedule_request= %s", json.dumps(monitoring_schedule_request, indent=4) ) self.sagemaker_client.create_monitoring_schedule(**monitoring_schedule_request) def update_monitoring_schedule( self, monitoring_schedule_name, schedule_expression=None, statistics_s3_uri=None, constraints_s3_uri=None, monitoring_inputs=None, monitoring_output_config=None, instance_count=None, instance_type=None, volume_size_in_gb=None, volume_kms_key=None, image_uri=None, entrypoint=None, arguments=None, record_preprocessor_source_uri=None, post_analytics_processor_source_uri=None, max_runtime_in_seconds=None, environment=None, network_config=None, role_arn=None, ): """Update an Amazon SageMaker monitoring schedule. Args: monitoring_schedule_name (str): The name of the monitoring schedule. The name must be unique within an AWS Region in an AWS account. Names should have a minimum length of 1 and a maximum length of 63 characters. schedule_expression (str): The cron expression that dictates the monitoring execution schedule. statistics_s3_uri (str): The S3 uri of the statistics file to use. constraints_s3_uri (str): The S3 uri of the constraints file to use. monitoring_inputs ([dict]): List of MonitoringInput dictionaries. monitoring_output_config (dict): A config dictionary, which contains a list of MonitoringOutput dictionaries, as well as an optional KMS key ID. instance_count (int): The number of instances to run. instance_type (str): The type of instance to run. volume_size_in_gb (int): Size of the volume in GB. volume_kms_key (str): KMS key to use when encrypting the volume. image_uri (str): The image uri to use for monitoring executions. 
entrypoint (str): The entrypoint to the monitoring execution image. arguments (str): The arguments to pass to the monitoring execution image. record_preprocessor_source_uri (str or None): The S3 uri that points to the script that pre-processes the dataset (only applicable to first-party images). post_analytics_processor_source_uri (str or None): The S3 uri that points to the script that post-processes the dataset (only applicable to first-party images). max_runtime_in_seconds (int): Specifies a limit to how long the processing job can run, in seconds. environment (dict): Environment variables to start the monitoring execution container with. network_config (dict): Specifies networking options, such as network traffic encryption between processing containers, whether to allow inbound and outbound network calls to and from processing containers, and VPC subnets and security groups to use for VPC-enabled processing jobs. role_arn (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. tags ([dict[str,str]]): A list of dictionaries containing key-value pairs. """ existing_desc = self.sagemaker_client.describe_monitoring_schedule( MonitoringScheduleName=monitoring_schedule_name ) existing_schedule_config = None if ( existing_desc.get("MonitoringScheduleConfig") is not None and existing_desc["MonitoringScheduleConfig"].get("ScheduleConfig") is not None and existing_desc["MonitoringScheduleConfig"]["ScheduleConfig"]["ScheduleExpression"] is not None ): existing_schedule_config = existing_desc["MonitoringScheduleConfig"]["ScheduleConfig"][ "ScheduleExpression" ] request_schedule_expression = schedule_expression or existing_schedule_config request_monitoring_inputs = ( monitoring_inputs or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringInputs" ] ) request_instance_count = ( instance_count or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringResources" ]["ClusterConfig"]["InstanceCount"] ) request_instance_type = ( instance_type or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringResources" ]["ClusterConfig"]["InstanceType"] ) request_volume_size_in_gb = ( volume_size_in_gb or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringResources" ]["ClusterConfig"]["VolumeSizeInGB"] ) request_image_uri = ( image_uri or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["ImageUri"] ) request_role_arn = ( role_arn or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["RoleArn"] ) monitoring_schedule_request = { "MonitoringScheduleName": monitoring_schedule_name, "MonitoringScheduleConfig": { "MonitoringJobDefinition": { "MonitoringInputs": request_monitoring_inputs, "MonitoringResources": { "ClusterConfig": { "InstanceCount": request_instance_count, "InstanceType": request_instance_type, "VolumeSizeInGB": request_volume_size_in_gb, } }, "MonitoringAppSpecification": {"ImageUri": request_image_uri}, "RoleArn": request_role_arn, } }, } if existing_schedule_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["ScheduleConfig"] = { "ScheduleExpression": request_schedule_expression } existing_monitoring_output_config = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ].get("MonitoringOutputConfig") if monitoring_output_config is not None or existing_monitoring_output_config is not None: 
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringOutputConfig" ] = (monitoring_output_config or existing_monitoring_output_config) existing_statistics_s3_uri = None existing_constraints_s3_uri = None if ( existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"].get( "BaselineConfig" ) is not None ): if ( existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "BaselineConfig" ].get("StatisticsResource") is not None ): existing_statistics_s3_uri = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ]["BaselineConfig"]["StatisticsResource"]["S3Uri"] if ( existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "BaselineConfig" ].get("ConstraintsResource") is not None ): existing_constraints_s3_uri = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ]["BaselineConfig"]["ConstraintsResource"]["S3Uri"] if ( statistics_s3_uri is not None or constraints_s3_uri is not None or existing_statistics_s3_uri is not None or existing_constraints_s3_uri is not None ): monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "BaselineConfig" ] = {} if statistics_s3_uri is not None or existing_statistics_s3_uri is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "BaselineConfig" ]["StatisticsResource"] = {"S3Uri": statistics_s3_uri or existing_statistics_s3_uri} if constraints_s3_uri is not None or existing_constraints_s3_uri is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "BaselineConfig" ]["ConstraintsResource"] = {"S3Uri": constraints_s3_uri or existing_constraints_s3_uri} existing_record_preprocessor_source_uri = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ]["MonitoringAppSpecification"].get("RecordPreprocessorSourceUri") if ( record_preprocessor_source_uri is not None or existing_record_preprocessor_source_uri is not None ): monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["RecordPreprocessorSourceUri"] = ( record_preprocessor_source_uri or existing_record_preprocessor_source_uri ) existing_post_analytics_processor_source_uri = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ]["MonitoringAppSpecification"].get("PostAnalyticsProcessorSourceUri") if ( post_analytics_processor_source_uri is not None or existing_post_analytics_processor_source_uri is not None ): monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["PostAnalyticsProcessorSourceUri"] = ( post_analytics_processor_source_uri or existing_post_analytics_processor_source_uri ) existing_entrypoint = existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ].get("ContainerEntrypoint") if entrypoint is not None or existing_entrypoint is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["ContainerEntrypoint"] = (entrypoint or existing_entrypoint) existing_arguments = existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ].get("ContainerArguments") if arguments is not None or existing_arguments is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringAppSpecification" ]["ContainerArguments"] = (arguments or existing_arguments)
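# NOTE: each optional field in this update request follows the same coalescing
# pattern: prefer the value passed to this call, fall back to the value already
# present on the described schedule, and only attach the key when at least one
# of the two is set.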
existing_volume_kms_key = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ]["MonitoringResources"]["ClusterConfig"].get("VolumeKmsKeyId") if volume_kms_key is not None or existing_volume_kms_key is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "MonitoringResources" ]["ClusterConfig"]["VolumeKmsKeyId"] = (volume_kms_key or existing_volume_kms_key) existing_max_runtime_in_seconds = None if existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"].get( "StoppingCondition" ): existing_max_runtime_in_seconds = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ]["StoppingCondition"].get("MaxRuntimeInSeconds") if max_runtime_in_seconds is not None or existing_max_runtime_in_seconds is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "StoppingCondition" ] = {"MaxRuntimeInSeconds": max_runtime_in_seconds or existing_max_runtime_in_seconds} existing_environment = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ].get("Environment") if environment is not None or existing_environment is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "Environment" ] = (environment or existing_environment) existing_network_config = existing_desc["MonitoringScheduleConfig"][ "MonitoringJobDefinition" ].get("NetworkConfig") _network_config = network_config or existing_network_config _network_config = resolve_nested_dict_value_from_config( _network_config, ["EnableInterContainerTrafficEncryption"], MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, sagemaker_session=self, ) if _network_config is not None: monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][ "NetworkConfig" ] = _network_config LOGGER.info("Updating monitoring schedule with name: %s.", monitoring_schedule_name) LOGGER.debug( "monitoring_schedule_request= %s", json.dumps(monitoring_schedule_request, indent=4) ) self.sagemaker_client.update_monitoring_schedule(**monitoring_schedule_request) def start_monitoring_schedule(self, monitoring_schedule_name): """Starts a monitoring schedule. Args: monitoring_schedule_name (str): The name of the Amazon SageMaker Monitoring Schedule to start. """ print() print("Starting Monitoring Schedule with name: {}".format(monitoring_schedule_name)) self.sagemaker_client.start_monitoring_schedule( MonitoringScheduleName=monitoring_schedule_name ) def stop_monitoring_schedule(self, monitoring_schedule_name): """Stops a monitoring schedule. Args: monitoring_schedule_name (str): The name of the Amazon SageMaker Monitoring Schedule to stop. """ print() print("Stopping Monitoring Schedule with name: {}".format(monitoring_schedule_name)) self.sagemaker_client.stop_monitoring_schedule( MonitoringScheduleName=monitoring_schedule_name ) def delete_monitoring_schedule(self, monitoring_schedule_name): """Deletes a monitoring schedule. Args: monitoring_schedule_name (str): The name of the Amazon SageMaker Monitoring Schedule to delete. """ print() print("Deleting Monitoring Schedule with name: {}".format(monitoring_schedule_name)) self.sagemaker_client.delete_monitoring_schedule( MonitoringScheduleName=monitoring_schedule_name ) def describe_monitoring_schedule(self, monitoring_schedule_name): """Calls the DescribeMonitoringSchedule API for the given name and returns the response. Args: monitoring_schedule_name (str): The name of the monitoring schedule to describe.
Returns: dict: A dictionary response with the monitoring schedule description. """ return self.sagemaker_client.describe_monitoring_schedule( MonitoringScheduleName=monitoring_schedule_name ) def list_monitoring_executions( self, monitoring_schedule_name, sort_by="ScheduledTime", sort_order="Descending", max_results=100, ): """Lists the monitoring executions associated with the given monitoring_schedule_name. Args: monitoring_schedule_name (str): The monitoring_schedule_name for which to retrieve the monitoring executions. sort_by (str): The field to sort by. Can be one of: "CreationTime", "ScheduledTime", "Status". Default: "ScheduledTime". sort_order (str): The sort order. Can be one of: "Ascending", "Descending". Default: "Descending". max_results (int): The maximum number of results to return. Must be between 1 and 100. Returns: dict: Dictionary of monitoring schedule executions. """ response = self.sagemaker_client.list_monitoring_executions( MonitoringScheduleName=monitoring_schedule_name, SortBy=sort_by, SortOrder=sort_order, MaxResults=max_results, ) return response def list_monitoring_schedules( self, endpoint_name=None, sort_by="CreationTime", sort_order="Descending", max_results=100 ): """Lists the monitoring schedules, optionally filtered by the given endpoint name. Args: endpoint_name (str): The name of the endpoint to filter on. If not provided, does not filter on it. Default: None. sort_by (str): The field to sort by. Can be one of: "Name", "CreationTime", "Status". Default: "CreationTime". sort_order (str): The sort order. Can be one of: "Ascending", "Descending". Default: "Descending". max_results (int): The maximum number of results to return. Must be between 1 and 100. Returns: dict: Dictionary of monitoring schedules. """ if endpoint_name is not None: response = self.sagemaker_client.list_monitoring_schedules( EndpointName=endpoint_name, SortBy=sort_by, SortOrder=sort_order, MaxResults=max_results, ) else: response = self.sagemaker_client.list_monitoring_schedules( SortBy=sort_by, SortOrder=sort_order, MaxResults=max_results ) return response def update_monitoring_alert( self, monitoring_schedule_name: str, monitoring_alert_name: str, data_points_to_alert: int, evaluation_period: int, ): """Update the monitoring alert associated with the given schedule_name and alert_name. Args: monitoring_schedule_name (str): The name of the monitoring schedule to update. monitoring_alert_name (str): The name of the monitoring alert to update. data_points_to_alert (int): The number of failing data points, within the evaluation period, required to raise the alert. evaluation_period (int): The number of most recent monitoring executions to consider when evaluating alert status. Returns: dict: A dict representing the update alert response. """ return self.sagemaker_client.update_monitoring_alert( MonitoringScheduleName=monitoring_schedule_name, MonitoringAlertName=monitoring_alert_name, DatapointsToAlert=data_points_to_alert, EvaluationPeriod=evaluation_period, ) def list_monitoring_alerts( self, monitoring_schedule_name: str, next_token: Optional[str] = None, max_results: Optional[int] = 10, ) -> Dict: """Lists the monitoring alerts associated with the given monitoring_schedule_name. Args: monitoring_schedule_name (str): The name of the monitoring schedule whose alerts are to be listed. next_token (Optional[str]): The pagination token. Default: None max_results (Optional[int]): The maximum number of results to return. Must be between 1 and 100. Default: 10 Returns: dict: list of monitoring alerts.
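Example:
    A minimal sketch of paging through all alerts on a schedule. The session
    variable and schedule name are illustrative; the ``MonitoringAlertSummaries``
    and ``NextToken`` keys reflect the shape of the underlying
    ListMonitoringAlerts response::

        alerts, token = [], None
        while True:
            resp = sagemaker_session.list_monitoring_alerts(
                monitoring_schedule_name="my-schedule",  # hypothetical name
                next_token=token,
                max_results=10,
            )
            alerts.extend(resp.get("MonitoringAlertSummaries", []))
            token = resp.get("NextToken")
            if not token:
                break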
""" params = { "MonitoringScheduleName": monitoring_schedule_name, "MaxResults": max_results, } if next_token: params.update({"NextToken": next_token}) return self.sagemaker_client.list_monitoring_alerts(**params) def list_monitoring_alert_history( self, monitoring_schedule_name: Optional[str] = None, monitoring_alert_name: Optional[str] = None, sort_by: Optional[str] = "CreationTime", sort_order: Optional[str] = "Descending", next_token: Optional[str] = None, max_results: Optional[int] = 10, creation_time_before: Optional[str] = None, creation_time_after: Optional[str] = None, status_equals: Optional[str] = None, ) -> Dict: """Lists the alert history associated with the given schedule_name and alert_name. Args: monitoring_schedule_name (Optional[str]): The name of the monitoring_schedule_name to filter on. If not provided, does not filter on it. Default: None. monitoring_alert_name (Optional[str]): The name of the monitoring_alert_name to filter on. If not provided, does not filter on it. Default: None. sort_by (Optional[str]): sort_by (str): The field to sort by. Can be one of: "Name", "CreationTime" Default: "CreationTime". sort_order (Optional[str]): The sort order. Can be one of: "Ascending", "Descending". Default: "Descending". next_token (Optional[str]): The pagination token. Default: None max_results (Optional[int]): The maximum number of results to return. Must be between 1 and 100. Default: 10. creation_time_before (Optional[str]): A filter to filter alert history before a time creation_time_after (Optional[str]): A filter to filter alert history after a time Default: None. status_equals (Optional[str]): A filter to filter alert history by status Default: None. Returns: dict: list of monitoring alert history. """ params = { "MonitoringScheduleName": monitoring_schedule_name, "SortBy": sort_by, "SortOrder": sort_order, "MaxResults": max_results, } if monitoring_alert_name: params.update({"MonitoringAlertName": monitoring_alert_name}) if creation_time_before: params.update({"CreationTimeBefore": creation_time_before}) if creation_time_after: params.update({"CreationTimeAfter": creation_time_after}) if status_equals: params.update({"StatusEquals": status_equals}) if next_token: params.update({"NextToken": next_token}) return self.sagemaker_client.list_monitoring_alert_history(**params) def was_processing_job_successful(self, job_name): """Calls the DescribeProcessingJob API for the given job name. It returns True if job was successful. Args: job_name (str): The name of the processing job to describe. Returns: bool: Whether the processing job was successful. """ job_desc = self.sagemaker_client.describe_processing_job(ProcessingJobName=job_name) return job_desc["ProcessingJobStatus"] == "Completed" def describe_processing_job(self, job_name): """Calls the DescribeProcessingJob API for the given job name and returns the response. Args: job_name (str): The name of the processing job to describe. Returns: dict: A dictionary response with the processing job description. """ return self.sagemaker_client.describe_processing_job(ProcessingJobName=job_name) def stop_processing_job(self, job_name): """Calls the StopProcessingJob API for the given job name. Args: job_name (str): The name of the processing job to stop. """ self.sagemaker_client.stop_processing_job(ProcessingJobName=job_name) def stop_training_job(self, job_name): """Calls the StopTrainingJob API for the given job name. Args: job_name (str): The name of the training job to stop. 
""" self.sagemaker_client.stop_training_job(TrainingJobName=job_name) def describe_training_job(self, job_name): """Calls the DescribeTrainingJob API for the given job name and returns the response. Args: job_name (str): The name of the training job to describe. Returns: dict: A dictionary response with the training job description. """ return self.sagemaker_client.describe_training_job(TrainingJobName=job_name) def auto_ml( self, input_config, output_config, auto_ml_job_config, role=None, job_name=None, problem_type=None, job_objective=None, generate_candidate_definitions_only=False, tags=None, model_deploy_config=None, ): """Create an Amazon SageMaker AutoML job. Args: input_config (list[dict]): A list of Channel objects. Each channel contains "DataSource" and "TargetAttributeName", "CompressionType" and "SampleWeightAttributeName" are optional fields. output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID. auto_ml_job_config (dict): A dict of AutoMLJob config, containing "StoppingCondition", "SecurityConfig", optionally contains "VolumeKmsKeyId". role (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob should have a unique job name. problem_type (str): The type of problem of this AutoMLJob. Valid values are "Regression", "BinaryClassification", "MultiClassClassification". If None, SageMaker AutoMLJob will infer the problem type automatically. job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional), "MetricName" and "Value". generate_candidate_definitions_only (bool): Indicates whether to only generate candidate definitions. If True, AutoML.list_candidates() cannot be called. Default: False. tags ([dict[str,str]]): A list of dictionaries containing key-value pairs. model_deploy_config (dict): Specifies how to generate the endpoint name for an automatic one-click Autopilot model deployment. Contains "AutoGenerateEndpointName" and "EndpointName" """ role = resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH, sagemaker_session=self) inferred_output_config = update_nested_dictionary_with_values_from_config( output_config, AUTO_ML_OUTPUT_CONFIG_PATH, sagemaker_session=self ) inferred_automl_job_config = update_nested_dictionary_with_values_from_config( auto_ml_job_config, AUTO_ML_JOB_CONFIG_PATH, sagemaker_session=self ) auto_ml_job_request = self._get_auto_ml_request( input_config=input_config, output_config=inferred_output_config, auto_ml_job_config=inferred_automl_job_config, role=role, job_name=job_name, problem_type=problem_type, job_objective=job_objective, generate_candidate_definitions_only=generate_candidate_definitions_only, tags=tags, model_deploy_config=model_deploy_config, ) def submit(request): LOGGER.info("Creating auto-ml-job with name: %s", job_name) LOGGER.debug("auto ml request: %s", json.dumps(request), indent=4) self.sagemaker_client.create_auto_ml_job(**request) self._intercept_create_request(auto_ml_job_request, submit, self.auto_ml.__name__) def _get_auto_ml_request( self, input_config, output_config, auto_ml_job_config, role, job_name, problem_type=None, job_objective=None, generate_candidate_definitions_only=False, tags=None, model_deploy_config=None, ): """Constructs a request compatible for creating an Amazon SageMaker AutoML job. Args: input_config (list[dict]): A list of Channel objects. 
Each channel contains "DataSource" and "TargetAttributeName"; "CompressionType" and "SampleWeightAttributeName" are optional fields. output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID. auto_ml_job_config (dict): A dict of AutoMLJob config, containing "StoppingCondition" and "SecurityConfig", and optionally "VolumeKmsKeyId". role (str): The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob should have a unique job name. problem_type (str): The type of problem of this AutoMLJob. Valid values are "Regression", "BinaryClassification", "MultiClassClassification". If None, SageMaker AutoMLJob will infer the problem type automatically. job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional), "MetricName" and "Value". generate_candidate_definitions_only (bool): Indicates whether to only generate candidate definitions. If True, AutoML.list_candidates() cannot be called. Default: False. tags ([dict[str,str]]): A list of dictionaries containing key-value pairs. model_deploy_config (dict): Specifies how to generate the endpoint name for an automatic one-click Autopilot model deployment. Contains "AutoGenerateEndpointName" and "EndpointName". Returns: Dict: an AutoML request dict """ auto_ml_job_request = { "AutoMLJobName": job_name, "InputDataConfig": input_config, "OutputDataConfig": output_config, "AutoMLJobConfig": auto_ml_job_config, "RoleArn": role, "GenerateCandidateDefinitionsOnly": generate_candidate_definitions_only, } if model_deploy_config is not None: auto_ml_job_request["ModelDeployConfig"] = model_deploy_config if job_objective is not None: auto_ml_job_request["AutoMLJobObjective"] = job_objective if problem_type is not None: auto_ml_job_request["ProblemType"] = problem_type tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB, TAGS) ) if tags is not None: auto_ml_job_request["Tags"] = tags return auto_ml_job_request def describe_auto_ml_job(self, job_name): """Calls the DescribeAutoMLJob API for the given job name and returns the response. Args: job_name (str): The name of the AutoML job to describe. Returns: dict: A dictionary response with the AutoML Job description. """ return self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name) def list_candidates( self, job_name, status_equals=None, candidate_name=None, candidate_arn=None, sort_order=None, sort_by=None, max_results=None, ): """Returns the list of candidates of an AutoML job for a given name. Args: job_name (str): The name of the AutoML job. If None, will use the object's latest_auto_ml_job name. status_equals (str): Filter the result with candidate status, values could be "Completed", "InProgress", "Failed", "Stopped", "Stopping". candidate_name (str): The name of a specified candidate to list. Defaults to None. candidate_arn (str): The ARN of a specified candidate to list. Defaults to None. sort_order (str): The order in which the candidates will be listed in the result. Defaults to None. sort_by (str): The value by which the candidates will be sorted. Defaults to None. max_results (int): The number of candidates to list in the results, between 1 and 100. Defaults to None. If None, will return all the candidates.
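Example:
    A minimal sketch listing the best completed candidates of a job. The job
    name is illustrative; ``Candidates`` is the key returned by the underlying
    ListCandidatesForAutoMLJob API::

        response = sagemaker_session.list_candidates(
            job_name="my-automl-job",
            status_equals="Completed",
            sort_by="FinalObjectiveMetricValue",
            sort_order="Descending",
            max_results=5,
        )
        top_candidates = response["Candidates"]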
Returns: list: A list of dictionaries with candidate information """ list_candidates_args = {"AutoMLJobName": job_name} if status_equals: list_candidates_args["StatusEquals"] = status_equals if candidate_name: list_candidates_args["CandidateNameEquals"] = candidate_name if candidate_arn: list_candidates_args["CandidateArnEquals"] = candidate_arn if sort_order: list_candidates_args["SortOrder"] = sort_order if sort_by: list_candidates_args["SortBy"] = sort_by if max_results: list_candidates_args["MaxResults"] = max_results return self.sagemaker_client.list_candidates_for_auto_ml_job(**list_candidates_args) def wait_for_auto_ml_job(self, job, poll=5): """Wait for an Amazon SageMaker AutoML job to complete. Args: job (str): Name of the auto ml job to wait for. poll (int): Polling interval in seconds (default: 5). Returns: (dict): Return value from the ``DescribeAutoMLJob`` API. Raises: exceptions.CapacityError: If the auto ml job fails with CapacityError. exceptions.UnexpectedStatusException: If the auto ml job fails. """ desc = _wait_until(lambda: _auto_ml_job_status(self.sagemaker_client, job), poll) _check_job_status(job, desc, "AutoMLJobStatus") return desc def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this method self, job_name, wait=False, poll=10 ): """Display logs for a given AutoML job, optionally tailing them until the job is complete. If the output is a tty or a Jupyter cell, it will be color-coded based on which instance the log entry is from. Args: job_name (str): Name of the Auto ML job to display the logs for. wait (bool): Whether to keep looking for new log entries until the job completes (default: False). poll (int): The interval in seconds between polling for new log entries and job completion (default: 10). Raises: exceptions.CapacityError: If waiting and the auto ml job fails with CapacityError. exceptions.UnexpectedStatusException: If waiting and the auto ml job fails. """ description = _wait_until(lambda: self.describe_auto_ml_job(job_name), poll) instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( self.boto_session, description, job="AutoML" ) state = _get_initial_job_state(description, "AutoMLJobStatus", wait) # The loop below implements a state machine that alternates between checking the job status # and reading whatever is available in the logs at this point. Note, that if we were # called with wait == False, we never check the job status. # # If wait == TRUE and job is not completed, the initial state is TAILING # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is # complete). # # The state table: # # STATE ACTIONS CONDITION NEW STATE # ---------------- ---------------- ----------------- ---------------- # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE # Else TAILING # JOB_COMPLETE Read logs, Pause Any COMPLETE # COMPLETE Read logs, Exit N/A # # Notes: # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to # Cloudwatch after the job was marked complete.
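# - While tailing, the DescribeAutoMLJob call below is throttled to at most
#   once every 30 seconds; log reads happen on every iteration of the loop.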
last_describe_job_call = time.time() while True: _flush_log_streams( stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap, ) if state == LogState.COMPLETE: break time.sleep(poll) if state == LogState.JOB_COMPLETE: state = LogState.COMPLETE elif time.time() - last_describe_job_call >= 30: description = self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name) last_describe_job_call = time.time() status = description["AutoMLJobStatus"] if status in ("Completed", "Failed", "Stopped"): print() state = LogState.JOB_COMPLETE if wait: _check_job_status(job_name, description, "AutoMLJobStatus") if dot: print() def compile_model( self, input_model_config, output_model_config, role=None, job_name=None, stop_condition=None, tags=None, ): """Create an Amazon SageMaker Neo compilation job. Args: input_model_config (dict): the trained model and the Amazon S3 location where it is stored. output_model_config (dict): Identifies the Amazon S3 location where you want Amazon SageMaker Neo to save the results of compilation job role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Neo compilation jobs use this role to access model artifacts. You must grant sufficient permissions to this role. job_name (str): Name of the compilation job being created. stop_condition (dict): Defines when compilation job shall finish. Contains entries that can be understood by the service like ``MaxRuntimeInSeconds``. tags (list[dict]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: str: ARN of the compile model job, if it is created. """ role = resolve_value_from_config( role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self ) inferred_output_model_config = update_nested_dictionary_with_values_from_config( output_model_config, COMPILATION_JOB_OUTPUT_CONFIG_PATH, sagemaker_session=self ) vpc_config = resolve_value_from_config( config_path=COMPILATION_JOB_VPC_CONFIG_PATH, sagemaker_session=self ) compilation_job_request = { "InputConfig": input_model_config, "OutputConfig": inferred_output_model_config, "RoleArn": role, "StoppingCondition": stop_condition, "CompilationJobName": job_name, } if vpc_config: compilation_job_request["VpcConfig"] = vpc_config tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, COMPILATION_JOB, TAGS) ) if tags is not None: compilation_job_request["Tags"] = tags LOGGER.info("Creating compilation-job with name: %s", job_name) self.sagemaker_client.create_compilation_job(**compilation_job_request) def package_model_for_edge( self, output_model_config, role=None, job_name=None, compilation_job_name=None, model_name=None, model_version=None, resource_key=None, tags=None, ): """Create an Amazon SageMaker Edge packaging job. Args: output_model_config (dict): Identifies the Amazon S3 location where you want Amazon SageMaker Edge to save the results of edge packaging job role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Edge edge packaging jobs use this role to access model artifacts. You must grant sufficient permissions to this role. job_name (str): Name of the edge packaging job being created. compilation_job_name (str): Name of the compilation job being created. resource_key (str): KMS key to encrypt the disk used to package the job tags (list[dict]): List of tags for labeling a compile model job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. 
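Example:
    A minimal sketch of packaging a compiled model for edge devices. All names,
    the role ARN, and the bucket are illustrative; the ``S3OutputLocation`` key
    mirrors the OutputConfig shape expected by the CreateEdgePackagingJob API::

        sagemaker_session.package_model_for_edge(
            output_model_config={"S3OutputLocation": "s3://my-bucket/edge/"},
            role="arn:aws:iam::111122223333:role/SageMakerRole",
            job_name="my-edge-packaging-job",
            compilation_job_name="my-compilation-job",
            model_name="my-model",
            model_version="1.0",
        )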
""" role = resolve_value_from_config(role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self) inferred_output_model_config = update_nested_dictionary_with_values_from_config( output_model_config, EDGE_PACKAGING_OUTPUT_CONFIG_PATH, sagemaker_session=self ) edge_packaging_job_request = { "OutputConfig": inferred_output_model_config, "RoleArn": role, "ModelName": model_name, "ModelVersion": model_version, "EdgePackagingJobName": job_name, "CompilationJobName": compilation_job_name, } resource_key = resolve_value_from_config( resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self ) tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, EDGE_PACKAGING_JOB, TAGS) ) if tags is not None: edge_packaging_job_request["Tags"] = tags if resource_key is not None: edge_packaging_job_request["ResourceKey"] = resource_key LOGGER.info("Creating edge-packaging-job with name: %s", job_name) self.sagemaker_client.create_edge_packaging_job(**edge_packaging_job_request) def tune( # noqa: C901 self, job_name, strategy, objective_type, objective_metric_name, max_jobs, max_parallel_jobs, parameter_ranges, static_hyperparameters, input_mode, metric_definitions, role, input_config, output_config, resource_config, stop_condition, tags, warm_start_config, max_runtime_in_seconds=None, strategy_config=None, completion_criteria_config=None, enable_network_isolation=False, image_uri=None, algorithm_arn=None, early_stopping_type="Off", encrypt_inter_container_traffic=False, vpc_config=None, use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, random_seed=None, environment=None, hpo_resource_config=None, autotune=False, auto_parameters=None, ): """Create an Amazon SageMaker hyperparameter tuning job. Args: job_name (str): Name of the tuning job being created. strategy (str): Strategy to be used for hyperparameter estimations. strategy_config (dict): A configuration for the hyperparameter tuning job optimisation strategy. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. max_jobs (int): Maximum total number of training jobs to start for the hyperparameter tuning job. max_parallel_jobs (int): Maximum number of parallel training jobs to start. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. static_hyperparameters (dict): Hyperparameters for model training. These hyperparameters remain unchanged across all of the training jobs for the hyperparameter tuning job. The hyperparameters are made accessible as a dictionary for the training code on SageMaker. image_uri (str): Docker image URI containing training code. algorithm_arn (str): Resource ARN for training algorithm created on or subscribed from AWS Marketplace (default: None). input_mode (str): The input mode that the algorithm supports. Valid modes: * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a directory in the Docker container. * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe. * 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of downloading the entire dataset before training begins. metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the training jobs. 
Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for the regular expression used to extract the metric from the logs. This should be defined only for jobs that don't use an Amazon algorithm. role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. You must grant sufficient permissions to this role. input_config (list): A list of Channel objects. Each channel is a named input source. Please refer to the format details described: https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID. resource_config (dict): Contains values for ResourceConfig: * instance_count (int): Number of EC2 instances to use for training. The key in resource_config is 'InstanceCount'. * instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. stop_condition (dict): When training should finish, e.g. ``MaxRuntimeInSeconds``. tags (list[dict]): List of tags for labeling the tuning job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. warm_start_config (dict): Configuration defining the type of warm start and other required configurations. max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds that a training job launched by a hyperparameter tuning job can run. completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A configuration for the completion criteria. early_stopping_type (str): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be attempted. If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to. enable_network_isolation (bool): Specifies whether to isolate the training container (default: ``False``). encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training jobs started for this hyperparameter tuning job (default: ``False``). vpc_config (dict): Contains values for VpcConfig (default: None): * subnets (list[str]): List of subnet ids. The key in vpc_config is 'Subnets'. * security_group_ids (list[str]): List of security group ids. The key in vpc_config is 'SecurityGroupIds'. use_spot_instances (bool): whether to use spot instances for training. checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints that the algorithm persists (if any) during training. (default: ``None``). checkpoint_local_path (str): The local path that the algorithm writes its checkpoints to. SageMaker will persist all files under this path to `checkpoint_s3_uri` continually during training. On job startup the reverse happens - data from the s3 location is downloaded to this path before the algorithm is started. If the path is unset then SageMaker assumes the checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). random_seed (int): An initial value used to initialize a pseudo-random number generator. Setting a random seed will make the hyperparameter tuning search strategies to produce more consistent configurations for the same tuning job. (default: ``None``). 
environment (dict[str, str]) : Environment variables to be set for use during training jobs (default: ``None``) hpo_resource_config (dict): The configuration for the hyperparameter tuning resources, including the compute instances and storage volumes, used for training jobs launched by the tuning job, where you must specify either instance_configs or instance_count + instance_type + volume_size: * instance_count (int): Number of EC2 instances to use for training. The key in resource_config is 'InstanceCount'. * instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. * volume_size (int or PipelineVariable): The volume size in GB of the data to be processed for hyperparameter optimisation * instance_configs (List[InstanceConfig]): A list containing the configuration(s) for one or more resources for processing hyperparameter jobs. These resources include compute instances and storage volumes to use in model training jobs. * volume_kms_key_id: The AWS Key Management Service (AWS KMS) key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. autotune (bool): Whether the parameter ranges or other unset settings of a tuning job should be chosen automatically (default: False). auto_parameters (dict[str, str]): Dictionary of auto parameters. The keys are names of auto parameters and values are example values of auto parameters (default: ``None``). """ tune_request = { "HyperParameterTuningJobName": job_name, "HyperParameterTuningJobConfig": self._map_tuning_config( strategy=strategy, max_jobs=max_jobs, max_parallel_jobs=max_parallel_jobs, max_runtime_in_seconds=max_runtime_in_seconds, objective_type=objective_type, objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, random_seed=random_seed, strategy_config=strategy_config, completion_criteria_config=completion_criteria_config, auto_parameters=auto_parameters, ), "TrainingJobDefinition": self._map_training_config( static_hyperparameters=static_hyperparameters, role=role, input_mode=input_mode, image_uri=image_uri, algorithm_arn=algorithm_arn, metric_definitions=metric_definitions, input_config=input_config, output_config=output_config, resource_config=resource_config, hpo_resource_config=hpo_resource_config, vpc_config=vpc_config, stop_condition=stop_condition, enable_network_isolation=enable_network_isolation, encrypt_inter_container_traffic=encrypt_inter_container_traffic, use_spot_instances=use_spot_instances, checkpoint_s3_uri=checkpoint_s3_uri, checkpoint_local_path=checkpoint_local_path, environment=environment, ), } if warm_start_config is not None: tune_request["WarmStartConfig"] = warm_start_config if autotune: tune_request["Autotune"] = {"Mode": "Enabled"} tags = _append_project_tags(tags) if tags is not None: tune_request["Tags"] = tags LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name) LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4)) self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request) def create_tuning_job( self, job_name, tuning_config, training_config=None, training_config_list=None, warm_start_config=None, tags=None, autotune=False, ): """Create an Amazon SageMaker hyperparameter tuning job. 
This method supports creating tuning jobs with single or multiple training algorithms (estimators), while the ``tune()`` method above only supports creating tuning jobs with single training algorithm. Args: job_name (str): Name of the tuning job being created. tuning_config (dict): Configuration to launch the tuning job. training_config (dict): Configuration to launch training jobs under the tuning job using a single algorithm. training_config_list (list[dict]): A list of configurations to launch training jobs under the tuning job using one or multiple algorithms. Either training_config or training_config_list should be provided, but not both. warm_start_config (dict): Configuration defining the type of warm start and other required configurations. tags (list[dict]): List of tags for labeling the tuning job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. autotune (bool): Whether the parameter ranges or other unset settings of a tuning job should be chosen automatically. """ if training_config is None and training_config_list is None: raise ValueError("Either training_config or training_config_list should be provided.") if training_config is not None and training_config_list is not None: raise ValueError( "Only one of training_config and training_config_list should be provided." ) tune_request = self._get_tuning_request( job_name=job_name, tuning_config=tuning_config, training_config=training_config, training_config_list=training_config_list, warm_start_config=warm_start_config, tags=tags, autotune=autotune, ) def submit(request): LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name) LOGGER.debug("tune request: %s", json.dumps(request, indent=4)) self.sagemaker_client.create_hyper_parameter_tuning_job(**request) self._intercept_create_request(tune_request, submit, self.create_tuning_job.__name__) def _get_tuning_request( self, job_name, tuning_config, training_config=None, training_config_list=None, warm_start_config=None, tags=None, autotune=False, ): """Construct CreateHyperParameterTuningJob request Args: job_name (str): Name of the tuning job being created. tuning_config (dict): Configuration to launch the tuning job. training_config (dict): Configuration to launch training jobs under the tuning job using a single algorithm. training_config_list (list[dict]): A list of configurations to launch training jobs under the tuning job using one or multiple algorithms. Either training_config or training_config_list should be provided, but not both. warm_start_config (dict): Configuration defining the type of warm start and other required configurations. tags (list[dict]): List of tags for labeling the tuning job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. autotune (bool): Whether the parameter ranges or other unset settings of a tuning job should be chosen automatically. 
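Example:
    A minimal sketch of the ``tuning_config`` dict this helper expands into
    ``_map_tuning_config(**tuning_config)``. The metric name and parameter
    range are illustrative::

        tuning_config = {
            "strategy": "Bayesian",
            "max_jobs": 10,
            "max_parallel_jobs": 2,
            "objective_type": "Minimize",
            "objective_metric_name": "validation:loss",
            "parameter_ranges": {
                "ContinuousParameterRanges": [
                    {"Name": "learning_rate", "MinValue": "0.01", "MaxValue": "0.2"}
                ]
            },
        }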
Returns: dict: A dictionary for CreateHyperParameterTuningJob request """ tune_request = { "HyperParameterTuningJobName": job_name, "HyperParameterTuningJobConfig": self._map_tuning_config(**tuning_config), } if autotune: tune_request["Autotune"] = {"Mode": "Enabled"} if training_config is not None: tune_request["TrainingJobDefinition"] = self._map_training_config(**training_config) if training_config_list is not None: tune_request["TrainingJobDefinitions"] = [ self._map_training_config(**training_cfg) for training_cfg in training_config_list ] if warm_start_config is not None: tune_request["WarmStartConfig"] = warm_start_config tags = _append_project_tags(tags) if tags is not None: tune_request["Tags"] = tags return tune_request def describe_tuning_job(self, job_name): """Calls DescribeHyperParameterTuningJob API for the given job name, returns the response. Args: job_name (str): The name of the hyperparameter tuning job to describe. Returns: dict: A dictionary response with the hyperparameter tuning job description. """ return self.sagemaker_client.describe_hyper_parameter_tuning_job( HyperParameterTuningJobName=job_name ) @classmethod def _map_tuning_config( cls, strategy, max_jobs, max_parallel_jobs, max_runtime_in_seconds=None, early_stopping_type="Off", objective_type=None, objective_metric_name=None, parameter_ranges=None, random_seed=None, strategy_config=None, completion_criteria_config=None, auto_parameters=None, ): """Construct tuning job configuration dictionary. Args: strategy (str): Strategy to be used for hyperparameter estimations. max_jobs (int): Maximum total number of training jobs to start for the hyperparameter tuning job. max_parallel_jobs (int): Maximum number of parallel training jobs to start. max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds that a training job launched by a hyperparameter tuning job can run. early_stopping_type (str): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be attempted. If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. random_seed (int): An initial value used to initialize a pseudo-random number generator. Setting a random seed will make the hyperparameter tuning search strategies produce more consistent configurations for the same tuning job. strategy_config (dict): A configuration for the hyperparameter tuning job optimisation strategy. completion_criteria_config (dict): A configuration for the completion criteria. auto_parameters (dict): Dictionary of auto parameters. The keys are names of auto parameters and values are example values of auto parameters. Returns: A dictionary of tuning job configuration.
For format details, please refer to HyperParameterTuningJobConfig as described in https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job """ tuning_config = { "Strategy": strategy, "ResourceLimits": { "MaxNumberOfTrainingJobs": max_jobs, "MaxParallelTrainingJobs": max_parallel_jobs, }, "TrainingJobEarlyStoppingType": early_stopping_type, } if max_runtime_in_seconds is not None: tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds if random_seed is not None: tuning_config["RandomSeed"] = random_seed tuning_objective = cls._map_tuning_objective(objective_type, objective_metric_name) if tuning_objective is not None: tuning_config["HyperParameterTuningJobObjective"] = tuning_objective if parameter_ranges is not None: tuning_config["ParameterRanges"] = parameter_ranges if auto_parameters is not None: if parameter_ranges is None: tuning_config["ParameterRanges"] = {} tuning_config["ParameterRanges"]["AutoParameters"] = [ {"Name": name, "ValueHint": value} for name, value in auto_parameters.items() ] if strategy_config is not None: tuning_config["StrategyConfig"] = strategy_config if completion_criteria_config is not None: tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config return tuning_config @classmethod def _map_tuning_objective(cls, objective_type, objective_metric_name): """Construct a dictionary of tuning objective from the arguments. Args: objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. Returns: A dictionary of tuning objective. For format details, please refer to HyperParameterTuningJobObjective as described in https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job """ tuning_objective = None if objective_type is not None or objective_metric_name is not None: tuning_objective = {} if objective_type is not None: tuning_objective["Type"] = objective_type if objective_metric_name is not None: tuning_objective["MetricName"] = objective_metric_name return tuning_objective @classmethod def _map_training_config( cls, static_hyperparameters, input_mode, role, output_config, stop_condition, input_config=None, resource_config=None, hpo_resource_config=None, metric_definitions=None, image_uri=None, algorithm_arn=None, vpc_config=None, enable_network_isolation=False, encrypt_inter_container_traffic=False, estimator_name=None, objective_type=None, objective_metric_name=None, parameter_ranges=None, use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, max_retry_attempts=None, environment=None, auto_parameters=None, ): """Construct a dictionary of training job configuration from the arguments. Args: static_hyperparameters (dict): Hyperparameters for model training. These hyperparameters remain unchanged across all of the training jobs for the hyperparameter tuning job. The hyperparameters are made accessible as a dictionary for the training code on SageMaker. input_mode (str): The input mode that the algorithm supports. Valid modes: * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a directory in the Docker container. * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe. 
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of downloading the entire dataset before training begins. role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. You must grant sufficient permissions to this role. output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID. resource_config (dict): Contains values for ResourceConfig: * instance_count (int): Number of EC2 instances to use for training. The key in resource_config is 'InstanceCount'. * instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. stop_condition (dict): When training should finish, e.g. ``MaxRuntimeInSeconds``. input_config (list): A list of Channel objects. Each channel is a named input source. Please refer to the format details described: https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for the regular expression used to extract the metric from the logs. This should be defined only for jobs that don't use an Amazon algorithm. image_uri (str): Docker image URI containing training code. algorithm_arn (str): Resource ARN for training algorithm created or subscribed on AWS Marketplace vpc_config (dict): Contains values for VpcConfig (default: None): * subnets (list[str]): List of subnet ids. The key in vpc_config is 'Subnets'. * security_group_ids (list[str]): List of security group ids. The key in vpc_config is 'SecurityGroupIds'. enable_network_isolation (bool): Specifies whether to isolate the training container encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training jobs started for this hyperparameter tuning job (default: ``False``). estimator_name (str): Unique name for the estimator. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. max_retry_attempts (int): The number of times to retry the job. environment (dict[str, str]) : Environment variables to be set for use during training jobs (default: ``None``) Returns: A dictionary of training job configuration. 
For format details, please refer to TrainingJobDefinition as described in https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job """ if hpo_resource_config is not None: resource_config_map = {"HyperParameterTuningResourceConfig": hpo_resource_config} else: resource_config_map = {"ResourceConfig": resource_config} training_job_definition = { "StaticHyperParameters": static_hyperparameters, "RoleArn": role, "OutputDataConfig": output_config, "StoppingCondition": stop_condition, **resource_config_map, } algorithm_spec = {"TrainingInputMode": input_mode} if metric_definitions is not None: algorithm_spec["MetricDefinitions"] = metric_definitions if algorithm_arn: algorithm_spec["AlgorithmName"] = algorithm_arn else: algorithm_spec["TrainingImage"] = image_uri training_job_definition["AlgorithmSpecification"] = algorithm_spec if input_config is not None: training_job_definition["InputDataConfig"] = input_config if vpc_config is not None: training_job_definition["VpcConfig"] = vpc_config if enable_network_isolation: training_job_definition["EnableNetworkIsolation"] = enable_network_isolation if encrypt_inter_container_traffic: training_job_definition[ "EnableInterContainerTrafficEncryption" ] = encrypt_inter_container_traffic if use_spot_instances: # use_spot_instances may be a Pipeline ParameterBoolean object # which is parsed during the Pipeline execution runtime training_job_definition["EnableManagedSpotTraining"] = use_spot_instances if checkpoint_s3_uri: checkpoint_config = {"S3Uri": checkpoint_s3_uri} if checkpoint_local_path: checkpoint_config["LocalPath"] = checkpoint_local_path training_job_definition["CheckpointConfig"] = checkpoint_config if estimator_name is not None: training_job_definition["DefinitionName"] = estimator_name tuning_objective = cls._map_tuning_objective(objective_type, objective_metric_name) if tuning_objective is not None: training_job_definition["TuningObjective"] = tuning_objective if parameter_ranges is not None: training_job_definition["HyperParameterRanges"] = parameter_ranges if auto_parameters is not None: if parameter_ranges is None: training_job_definition["HyperParameterRanges"] = {} training_job_definition["HyperParameterRanges"]["AutoParameters"] = [ {"Name": name, "ValueHint": value} for name, value in auto_parameters.items() ] if max_retry_attempts is not None: training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts} if environment is not None: training_job_definition["Environment"] = environment return training_job_definition def stop_tuning_job(self, name): """Stop the Amazon SageMaker hyperparameter tuning job with the specified name. Args: name (str): Name of the Amazon SageMaker hyperparameter tuning job. Raises: ClientError: If an error occurs while trying to stop the hyperparameter tuning job. """ try: LOGGER.info("Stopping tuning job: %s", name) self.sagemaker_client.stop_hyper_parameter_tuning_job(HyperParameterTuningJobName=name) except ClientError as e: error_code = e.response["Error"]["Code"] # allow to pass if the job already stopped if error_code == "ValidationException": LOGGER.info("Tuning job: %s is already stopped or not running.", name) else: LOGGER.error( "Error occurred while attempting to stop tuning job: %s. 
Please try again.", name, ) raise def _get_transform_request( self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env, input_config, output_config, resource_config, experiment_config, tags, data_processing, model_client_config=None, batch_data_capture_config: BatchDataCaptureConfig = None, ): """Construct an dict can be used to create an Amazon SageMaker transform job. Args: job_name (str): Name of the transform job being created. model_name (str): Name of the SageMaker model being used for the transform job. strategy (str): The strategy used to decide how to batch records in a single request. Possible values are 'MultiRecord' and 'SingleRecord'. max_concurrent_transforms (int): The maximum number of HTTP requests to be made to each individual transform container at one time. max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. env (dict): Environment variables to be set for use during the transform job. input_config (dict): A dictionary describing the input data (and its location) for the job. output_config (dict): A dictionary describing the output location for the job. resource_config (dict): A dictionary describing the resources to complete the job. experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. * If `TrialName` is supplied and the Trial already exists the job's Trial Component will be associated with the Trial. * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. tags (list[dict]): List of tags for labeling a transform job. data_processing(dict): A dictionary describing config for combining the input data and transformed data. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. model_client_config (dict): A dictionary describing the model configuration for the job. Dictionary contains two optional keys, 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. 
            batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
                specifies the configurations related to the batch data capture for the
                transform job (default: None).

        Returns:
            Dict: A create transform job request dict.
        """
        transform_request = {
            "TransformJobName": job_name,
            "ModelName": model_name,
            "TransformInput": input_config,
            "TransformOutput": output_config,
            "TransformResources": resource_config,
        }

        if strategy is not None:
            transform_request["BatchStrategy"] = strategy

        if max_concurrent_transforms is not None:
            transform_request["MaxConcurrentTransforms"] = max_concurrent_transforms

        if max_payload is not None:
            transform_request["MaxPayloadInMB"] = max_payload

        if env is not None:
            transform_request["Environment"] = env

        if tags is not None:
            transform_request["Tags"] = tags

        if data_processing is not None:
            transform_request["DataProcessing"] = data_processing

        if experiment_config and len(experiment_config) > 0:
            transform_request["ExperimentConfig"] = experiment_config

        if model_client_config and len(model_client_config) > 0:
            transform_request["ModelClientConfig"] = model_client_config

        if batch_data_capture_config is not None:
            transform_request["DataCaptureConfig"] = batch_data_capture_config._to_request_dict()

        return transform_request

    def transform(
        self,
        job_name,
        model_name,
        strategy,
        max_concurrent_transforms,
        max_payload,
        input_config,
        output_config,
        resource_config,
        experiment_config,
        env: Optional[Dict[str, str]] = None,
        tags=None,
        data_processing=None,
        model_client_config=None,
        batch_data_capture_config: BatchDataCaptureConfig = None,
    ):
        """Create an Amazon SageMaker transform job.

        Args:
            job_name (str): Name of the transform job being created.
            model_name (str): Name of the SageMaker model being used for the transform job.
            strategy (str): The strategy used to decide how to batch records in a single
                request. Possible values are 'MultiRecord' and 'SingleRecord'.
            max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
                each individual transform container at one time.
            max_payload (int): Maximum size of the payload in a single HTTP request to the
                container in MB.
            env (dict): Environment variables to be set for use during the transform job.
            input_config (dict): A dictionary describing the input data (and its location)
                for the job.
            output_config (dict): A dictionary describing the output location for the job.
            resource_config (dict): A dictionary describing the resources to complete the job.
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
                The behavior of setting these keys is as follows:

                * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If `TrialName` is supplied and the Trial already exists, the job's Trial
                  Component will be associated with the Trial.
                * If both `ExperimentName` and `TrialName` are not supplied, the Trial Component
                  will be unassociated.
                * `TrialComponentDisplayName` is used for display in Studio.
            tags (list[dict]): List of tags for labeling a transform job. For more information
                about tags, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            data_processing (dict): A dictionary describing the config for combining the input
                data and transformed data.
            model_client_config (dict): A dictionary describing the model configuration for the
                job.
Dictionary contains two optional keys, 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. batch_data_capture_config (BatchDataCaptureConfig): Configuration object which specifies the configurations related to the batch data capture for the transform job """ tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, TRANSFORM_JOB, TAGS) ) batch_data_capture_config = resolve_class_attribute_from_config( None, batch_data_capture_config, "kms_key_id", TRANSFORM_JOB_KMS_KEY_ID_PATH, sagemaker_session=self, ) output_config = resolve_nested_dict_value_from_config( output_config, [KMS_KEY_ID], TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, sagemaker_session=self ) resource_config = resolve_nested_dict_value_from_config( resource_config, [VOLUME_KMS_KEY_ID], TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH, sagemaker_session=self, ) env = resolve_value_from_config( direct_input=env, config_path=TRANSFORM_JOB_ENVIRONMENT_PATH, default_value=None, sagemaker_session=self, ) transform_request = self._get_transform_request( job_name=job_name, model_name=model_name, strategy=strategy, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=env, input_config=input_config, output_config=output_config, resource_config=resource_config, experiment_config=experiment_config, tags=tags, data_processing=data_processing, model_client_config=model_client_config, batch_data_capture_config=batch_data_capture_config, ) def submit(request): LOGGER.info("Creating transform job with name: %s", job_name) LOGGER.debug("Transform request: %s", json.dumps(request, indent=4)) self.sagemaker_client.create_transform_job(**request) self._intercept_create_request(transform_request, submit, self.transform.__name__) def _create_model_request( self, name, role, container_defs, vpc_config=None, enable_network_isolation=False, primary_container=None, tags=None, ): # pylint: disable=redefined-outer-name """Placeholder docstring""" if container_defs and primary_container: raise ValueError("Both container_defs and primary_container can not be passed as input") if primary_container: msg = ( "primary_container is going to be deprecated in a future release. Please use " "container_defs instead." ) warnings.warn(msg, DeprecationWarning) container_defs = primary_container role = self.expand_role(role) if isinstance(container_defs, list): update_list_of_dicts_with_values_from_config( container_defs, MODEL_CONTAINERS_PATH, sagemaker_session=self ) container_definition = container_defs else: container_definition = _expand_container_def(container_defs) container_definition = update_nested_dictionary_with_values_from_config( container_definition, MODEL_PRIMARY_CONTAINER_PATH, sagemaker_session=self ) request = {"ModelName": name, "ExecutionRoleArn": role} if isinstance(container_definition, list): request["Containers"] = container_definition elif "ModelPackageName" in container_definition: request["Containers"] = [container_definition] else: request["PrimaryContainer"] = container_definition if tags: request["Tags"] = tags if vpc_config: request["VpcConfig"] = vpc_config if enable_network_isolation: # enable_network_isolation may be a pipeline variable which is # parsed in execution time request["EnableNetworkIsolation"] = enable_network_isolation return request def create_model( self, name, role=None, container_defs=None, vpc_config=None, enable_network_isolation=None, primary_container=None, tags=None, ): """Create an Amazon SageMaker ``Model``. 
        Specify the S3 location of the model artifacts and Docker image containing
        the inference code. Amazon SageMaker uses this information to deploy the
        model in Amazon SageMaker. This method can also be used to create a Model for an
        Inference Pipeline if you pass the list of container definitions through the
        ``container_defs`` parameter.

        Args:
            name (str): Name of the Amazon SageMaker ``Model`` to create.
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker
                training jobs and APIs that create Amazon SageMaker endpoints use this role
                to access training data and model artifacts.
                You must grant sufficient permissions to this role.
            container_defs (list[dict[str, str]] or dict[str, str]): A single container
                definition or a list of container definitions which will be invoked
                sequentially while performing the prediction. If the list contains only one
                container, it is passed to SageMaker Hosting as the ``PrimaryContainer``;
                otherwise, it is passed as ``Containers``. You can also specify the return
                value of ``sagemaker.container_def()`` or
                ``sagemaker.pipeline_container_def()``, which will be used to create more
                advanced container configurations, including model containers which need
                artifacts from S3.
            vpc_config (dict[str, list[str]]): The VpcConfig set on the model
                (default: None)

                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.

            enable_network_isolation (bool): Whether the model requires network isolation
                or not.
            primary_container (str or dict[str, str]): Docker image which defines the
                inference code. You can also specify the return value of
                ``sagemaker.container_def()``, which is used to create more advanced
                container configurations, including model containers which need artifacts
                from S3. This field is deprecated; please use ``container_defs`` instead.
            tags (List[dict[str, str]]): Optional. The list of tags to add to the model.

                Example:
                    >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
                    For more information about tags, see
                    https://boto3.amazonaws.com/v1/documentation\
                    /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags

        Returns:
            str: Name of the Amazon SageMaker ``Model`` created.
        """
        tags = _append_project_tags(tags)
        tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS))
        role = resolve_value_from_config(
            role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self
        )
        vpc_config = resolve_value_from_config(
            vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self
        )
        enable_network_isolation = resolve_value_from_config(
            direct_input=enable_network_isolation,
            config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
            default_value=False,
            sagemaker_session=self,
        )

        # Due to ambiguity in container_defs, which accepts both a single container
        # definition (dtype: dict) and a list of container definitions (dtype: list),
        # we need to inject environment variables into the container_defs in the helper
        # function _create_model_request.
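        # For illustration only (hypothetical values), both accepted shapes look like:
        #   container_defs = {"Image": "<ecr-image-uri>", "ModelDataUrl": "s3://bucket/model.tar.gz"}
        #   container_defs = [{"Image": "<uri-1>"}, {"Image": "<uri-2>"}]  # an Inference Pipeline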
create_model_request = self._create_model_request( name=name, role=role, container_defs=container_defs, vpc_config=vpc_config, enable_network_isolation=enable_network_isolation, primary_container=primary_container, tags=tags, ) def submit(request): LOGGER.info("Creating model with name: %s", name) LOGGER.debug("CreateModel request: %s", json.dumps(request, indent=4)) try: self.sagemaker_client.create_model(**request) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] if ( error_code == "ValidationException" and "Cannot create already existing model" in message ): LOGGER.warning("Using already existing model: %s", name) else: raise self._intercept_create_request(create_model_request, submit, self.create_model.__name__) return name def create_model_from_job( self, training_job_name, name=None, role=None, image_uri=None, model_data_url=None, env=None, enable_network_isolation=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, tags=None, ): """Create an Amazon SageMaker ``Model`` from a SageMaker Training Job. Args: training_job_name (str): The Amazon SageMaker Training Job name. name (str): The name of the SageMaker ``Model`` to create (default: None). If not specified, the training job name is used. role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, specified either by an IAM role name or role ARN. If None, the ``RoleArn`` from the SageMaker Training Job will be used. image_uri (str): The Docker image URI (default: None). If None, it defaults to the training image URI from ``training_job_name``. model_data_url (str): S3 location of the model data (default: None). If None, defaults to the ``ModelS3Artifacts`` of ``training_job_name``. env (dict[string,string]): Model environment variables (default: {}). enable_network_isolation (bool): Whether the model requires network isolation or not. vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the model. Default: use VpcConfig from training job. * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. tags(List[dict[str, str]]): Optional. The list of tags to add to the model. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: str: The name of the created ``Model``. 
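        Example:
            A minimal usage sketch; the training job name is a placeholder and
            ``sagemaker_session`` is assumed to be an existing ``Session`` instance:

            >>> model_name = sagemaker_session.create_model_from_job("my-training-job")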
""" training_job = self.sagemaker_client.describe_training_job( TrainingJobName=training_job_name ) name = name or training_job_name role = role or training_job["RoleArn"] role = resolve_value_from_config( role, MODEL_EXECUTION_ROLE_ARN_PATH, training_job["RoleArn"], self ) enable_network_isolation = resolve_value_from_config( direct_input=enable_network_isolation, config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH, default_value=False, sagemaker_session=self, ) env = resolve_value_from_config( env, MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH, default_value={}, sagemaker_session=self, ) primary_container = container_def( image_uri or training_job["AlgorithmSpecification"]["TrainingImage"], model_data_url=model_data_url or self._gen_s3_model_data_source(training_job), env=env, ) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) vpc_config = resolve_value_from_config( vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self ) return self.create_model( name, role, primary_container, enable_network_isolation=enable_network_isolation, vpc_config=vpc_config, tags=tags, ) def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data): """Create a SageMaker Model Package from the results of training with an Algorithm Package. Args: name (str): ModelPackage name description (str): Model Package description algorithm_arn (str): arn or name of the algorithm used for training. model_data (str): s3 URI to the model artifacts produced by training """ request = { "ModelPackageName": name, "ModelPackageDescription": description, "SourceAlgorithmSpecification": { "SourceAlgorithms": [{"AlgorithmName": algorithm_arn, "ModelDataUrl": model_data}] }, } try: LOGGER.info("Creating model package with name: %s", name) self.sagemaker_client.create_model_package(**request) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] if error_code == "ValidationException" and "ModelPackage already exists" in message: LOGGER.warning("Using already existing model package: %s", name) else: raise def create_model_package_from_containers( self, containers=None, content_types=None, response_types=None, inference_instances=None, transform_instances=None, model_package_name=None, model_package_group_name=None, model_metrics=None, metadata_properties=None, marketplace_cert=False, approval_status="PendingManualApproval", description=None, drift_check_baselines=None, customer_metadata_properties=None, validation_specification=None, domain=None, sample_payload_url=None, task=None, ): """Get request dictionary for CreateModelPackage API. Args: containers (list): A list of inference containers that can be used for inference specifications of Model Package (default: None). content_types (list): The supported MIME types for the input data (default: None). response_types (list): The supported MIME types for the output data (default: None). inference_instances (list): A list of the instance types that are used to generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). 
            model_package_group_name (str): Model Package Group name, exclusive to
                `model_package_name`, using `model_package_group_name` makes the Model Package
                versioned (default: None).
            model_metrics (ModelMetrics): ModelMetrics object (default: None).
            metadata_properties (MetadataProperties): MetadataProperties object
                (default: None).
            marketplace_cert (bool): A boolean value indicating if the Model Package is
                certified for AWS Marketplace (default: False).
            approval_status (str): Model Approval Status, values can be "Approved",
                "Rejected", or "PendingManualApproval" (default: "PendingManualApproval").
            description (str): Model Package description (default: None).
            drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object
                (default: None).
            customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
                metadata properties (default: None).
            domain (str): Domain values can be "COMPUTER_VISION",
                "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
            sample_payload_url (str): The S3 path where the sample payload is stored
                (default: None).
            task (str): Task values which are supported by Inference Recommender are
                "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION",
                "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER"
                (default: None).
        """
        if containers:
            # Containers are provided. Now we can merge missing entries from the config.
            # If Containers are not provided, it is safe to ignore the config. This is
            # because, if this object is provided to the API, then Image is required for
            # Containers, and that is not supported by the config now. So if we merged
            # values from the config, the API would throw an exception. In the future, when
            # SageMaker Config starts supporting other parameters, we can add that.
            update_list_of_dicts_with_values_from_config(
                containers,
                MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
                required_key_paths=["Image"],
                sagemaker_session=self,
            )
        if validation_specification:
            # ValidationSpecification is provided. Now we can merge missing entries from the
            # config. If ValidationSpecification is not provided, it is safe to ignore the
            # config. This is because, if this object is provided to the API, then both
            # ValidationProfiles and ValidationRole are required, and for ValidationProfile,
            # ProfileName is a required parameter. That is not supported by the config now.
            # So if we merged values from the config, the API would throw an exception. In
            # the future, when SageMaker Config starts supporting other parameters, we can
            # add that.
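            # For illustration only, a hypothetical shape of the incoming dict, matching
            # the keys resolved below:
            #   validation_specification = {
            #       "ValidationRole": "arn:aws:iam::111122223333:role/example-validation-role",
            #       "ValidationProfiles": [
            #           {"ProfileName": "example-profile", "TransformJobDefinition": {...}},
            #       ],
            #   }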
            validation_role = resolve_value_from_config(
                validation_specification.get(VALIDATION_ROLE, None),
                MODEL_PACKAGE_VALIDATION_ROLE_PATH,
                sagemaker_session=self,
            )
            validation_specification[VALIDATION_ROLE] = validation_role
            validation_profiles = validation_specification.get(VALIDATION_PROFILES, [])
            update_list_of_dicts_with_values_from_config(
                validation_profiles,
                MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
                required_key_paths=["ProfileName", "TransformJobDefinition"],
                sagemaker_session=self,
            )
        model_pkg_request = get_create_model_package_request(
            model_package_name,
            model_package_group_name,
            containers,
            content_types,
            response_types,
            inference_instances,
            transform_instances,
            model_metrics,
            metadata_properties,
            marketplace_cert,
            approval_status,
            description,
            drift_check_baselines=drift_check_baselines,
            customer_metadata_properties=customer_metadata_properties,
            validation_specification=validation_specification,
            domain=domain,
            sample_payload_url=sample_payload_url,
            task=task,
        )

        def submit(request):
            if model_package_group_name is not None and not model_package_group_name.startswith(
                "arn:"
            ):
                _create_resource(
                    lambda: self.sagemaker_client.create_model_package_group(
                        ModelPackageGroupName=request["ModelPackageGroupName"]
                    )
                )
            return self.sagemaker_client.create_model_package(**request)

        return self._intercept_create_request(
            model_pkg_request, submit, self.create_model_package_from_containers.__name__
        )

    def wait_for_model_package(self, model_package_name, poll=5):
        """Wait for an Amazon SageMaker Model Package creation to complete.

        Args:
            model_package_name (str): Name of the Model Package to wait for.
            poll (int): Polling interval in seconds (default: 5).

        Returns:
            dict: Return value from the ``DescribeModelPackage`` API.

        Raises:
            exceptions.CapacityError: If the Model Package job fails with CapacityError.
            exceptions.UnexpectedStatusException: If waiting and the Model Package job fails.
        """
        desc = _wait_until(
            lambda: _create_model_package_status(self.sagemaker_client, model_package_name), poll
        )
        status = desc["ModelPackageStatus"]

        if status != "Completed":
            reason = desc.get("FailureReason", None)
            message = "Error creating model package {package}: {status} Reason: {reason}".format(
                package=model_package_name, status=status, reason=reason
            )
            if "CapacityError" in str(reason):
                raise exceptions.CapacityError(
                    message=message,
                    allowed_statuses=["Completed"],
                    actual_status=status,
                )
            raise exceptions.UnexpectedStatusException(
                message=message,
                allowed_statuses=["Completed"],
                actual_status=status,
            )
        return desc

    def describe_model(self, name):
        """Calls the DescribeModel API for the given model name.

        Args:
            name (str): The name of the SageMaker model.

        Returns:
            dict: A dictionary response with the model description.
        """
        return self.sagemaker_client.describe_model(ModelName=name)

    def create_endpoint_config(
        self,
        name,
        model_name,
        initial_instance_count,
        instance_type,
        accelerator_type=None,
        tags=None,
        kms_key=None,
        data_capture_config_dict=None,
        volume_size=None,
        model_data_download_timeout=None,
        container_startup_health_check_timeout=None,
        explainer_config_dict=None,
    ):
        """Create an Amazon SageMaker endpoint configuration.

        The endpoint configuration identifies the Amazon SageMaker model (created using the
        ``CreateModel`` API) and the hardware configuration on which to deploy the model.
        Provide this endpoint configuration to the ``CreateEndpoint`` API, which then launches
        the hardware and deploys the model.

        Args:
            name (str): Name of the Amazon SageMaker endpoint configuration to create.
            model_name (str): Name of the Amazon SageMaker ``Model``.
            initial_instance_count (int): Minimum number of EC2 instances to launch. The actual
                number of active instances for an endpoint at any given time varies due to
                autoscaling.
            instance_type (str): Type of EC2 instance to launch, for example, 'ml.c4.xlarge'.
            accelerator_type (str): Type of Elastic Inference accelerator to attach to the
                instance. For example, 'ml.eia1.medium'. For more information:
                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
            tags (List[dict[str, str]]): Optional. The list of tags to add to the endpoint
                config.
            kms_key (str): The KMS key that is used to encrypt the data on the storage volume
                attached to the instance hosting the endpoint.
            data_capture_config_dict (dict): Specifies configuration related to Endpoint data
                capture for use with Amazon SageMaker Model Monitoring. Default: None.
            volume_size (int): The size, in GB, of the ML storage volume attached to the
                individual inference instance associated with the production variant.
                Currently only Amazon EBS gp2 storage volumes are supported.
            model_data_download_timeout (int): The timeout value, in seconds, to download and
                extract model data from Amazon S3 to the individual inference instance
                associated with this production variant.
            container_startup_health_check_timeout (int): The timeout value, in seconds, for
                your inference container to pass a health check by SageMaker Hosting. For more
                information about health checks, see:
                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
            explainer_config_dict (dict): Specifies configuration to enable explainers.
                Default: None.

        Example:
            >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
            For more information about tags, see
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags

        Returns:
            str: Name of the endpoint configuration created.
        """
        LOGGER.info("Creating endpoint-config with name %s", name)

        tags = tags or []
        provided_production_variant = production_variant(
            model_name,
            instance_type,
            initial_instance_count,
            accelerator_type=accelerator_type,
            volume_size=volume_size,
            model_data_download_timeout=model_data_download_timeout,
            container_startup_health_check_timeout=container_startup_health_check_timeout,
        )
        production_variants = [provided_production_variant]
        # Currently we just inject CoreDumpConfig.KmsKeyId from the config for the production
        # variant. But if that parameter is injected, then CoreDumpConfig.DestinationS3Uri
        # needs to be present, and the SageMaker Python SDK doesn't support
        # CoreDumpConfig.DestinationS3Uri.
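        # In other words (a sketch of the intended behavior, not a guarantee): a KmsKeyId
        # coming from the sagemaker config is only merged into a variant that already
        # carries CoreDumpConfig.DestinationS3Uri, since the API rejects one without the
        # other.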
        update_list_of_dicts_with_values_from_config(
            production_variants,
            ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
            required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
            sagemaker_session=self,
        )
        request = {
            "EndpointConfigName": name,
            "ProductionVariants": production_variants,
        }

        tags = _append_project_tags(tags)
        tags = self._append_sagemaker_config_tags(
            tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
        )
        if tags is not None:
            request["Tags"] = tags

        kms_key = (
            resolve_value_from_config(
                kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
            )
            if instance_supports_kms(instance_type)
            else kms_key
        )
        if kms_key is not None:
            request["KmsKeyId"] = kms_key

        if data_capture_config_dict is not None:
            inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config(
                data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self
            )
            request["DataCaptureConfig"] = inferred_data_capture_config_dict

        if explainer_config_dict is not None:
            request["ExplainerConfig"] = explainer_config_dict

        self.sagemaker_client.create_endpoint_config(**request)
        return name

    def create_endpoint_config_from_existing(
        self,
        existing_config_name,
        new_config_name,
        new_tags=None,
        new_kms_key=None,
        new_data_capture_config_dict=None,
        new_production_variants=None,
        new_explainer_config_dict=None,
    ):
        """Create an Amazon SageMaker endpoint configuration from an existing one.

        It also updates any values that were passed in.
        The endpoint configuration identifies the Amazon SageMaker model (created using the
        ``CreateModel`` API) and the hardware configuration on which to deploy the model.
        Provide this endpoint configuration to the ``CreateEndpoint`` API, which then launches
        the hardware and deploys the model.

        Args:
            new_config_name (str): Name of the Amazon SageMaker endpoint configuration to
                create.
            existing_config_name (str): Name of the existing Amazon SageMaker endpoint
                configuration.
            new_tags (list[dict[str, str]]): Optional. The list of tags to add to the endpoint
                config. If not specified, the tags of the existing endpoint configuration are
                used. If any of the existing tags are reserved AWS ones (i.e. begin with
                "aws"), they are not carried over to the new endpoint configuration.
            new_kms_key (str): The KMS key that is used to encrypt the data on the storage
                volume attached to the instance hosting the endpoint (default: None).
                If not specified, the KMS key of the existing endpoint configuration is used.
            new_data_capture_config_dict (dict): Specifies configuration related to Endpoint
                data capture for use with Amazon SageMaker Model Monitoring (default: None).
                If not specified, the data capture configuration of the existing endpoint
                configuration is used.
            new_production_variants (list[dict]): The configuration for which model(s) to host
                and the resources to deploy for hosting the model(s). If not specified, the
                ``ProductionVariants`` of the existing endpoint configuration is used.
            new_explainer_config_dict (dict): Specifies configuration to enable explainers
                (default: None). If not specified, the explainer configuration of the existing
                endpoint configuration is used.

        Returns:
            str: Name of the endpoint configuration created.
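        Example:
            A minimal usage sketch; both configuration names below are placeholders:

            >>> sagemaker_session.create_endpoint_config_from_existing(
            ...     existing_config_name="existing-endpoint-config",
            ...     new_config_name="new-endpoint-config",
            ... )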
""" LOGGER.info("Creating endpoint-config with name %s", new_config_name) existing_endpoint_config_desc = self.sagemaker_client.describe_endpoint_config( EndpointConfigName=existing_config_name ) request = { "EndpointConfigName": new_config_name, } production_variants = ( new_production_variants or existing_endpoint_config_desc["ProductionVariants"] ) update_list_of_dicts_with_values_from_config( production_variants, ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH, required_key_paths=["CoreDumpConfig.DestinationS3Uri"], sagemaker_session=self, ) request["ProductionVariants"] = production_variants request_tags = new_tags or self.list_tags( existing_endpoint_config_desc["EndpointConfigArn"] ) request_tags = _append_project_tags(request_tags) request_tags = self._append_sagemaker_config_tags( request_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS) ) if request_tags: request["Tags"] = request_tags if new_kms_key is not None or existing_endpoint_config_desc.get("KmsKeyId") is not None: request["KmsKeyId"] = new_kms_key or existing_endpoint_config_desc.get("KmsKeyId") supports_kms = any( [ instance_supports_kms(production_variant["InstanceType"]) for production_variant in production_variants if "InstanceType" in production_variant ] ) if KMS_KEY_ID not in request and supports_kms: kms_key_from_config = resolve_value_from_config( config_path=ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self ) if kms_key_from_config: request[KMS_KEY_ID] = kms_key_from_config request_data_capture_config_dict = ( new_data_capture_config_dict or existing_endpoint_config_desc.get("DataCaptureConfig") ) if request_data_capture_config_dict is not None: inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config( request_data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self, ) request["DataCaptureConfig"] = inferred_data_capture_config_dict async_inference_config_dict = existing_endpoint_config_desc.get( "AsyncInferenceConfig", None ) if async_inference_config_dict is not None: inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config( async_inference_config_dict, ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, sagemaker_session=self, ) request["AsyncInferenceConfig"] = inferred_async_inference_config_dict request_explainer_config_dict = ( new_explainer_config_dict or existing_endpoint_config_desc.get("ExplainerConfig", None) ) if request_explainer_config_dict is not None: request["ExplainerConfig"] = request_explainer_config_dict self.sagemaker_client.create_endpoint_config(**request) def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): """Create an Amazon SageMaker ``Endpoint`` according to the configuration in the request. Once the ``Endpoint`` is created, client applications can send requests to obtain inferences. The endpoint configuration is created using the ``CreateEndpointConfig`` API. Args: endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` being created. config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy. wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True). tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint (default: None). Returns: str: Name of the Amazon SageMaker ``Endpoint`` created. 
""" LOGGER.info("Creating endpoint with name %s", endpoint_name) tags = tags or [] tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS) ) self.sagemaker_client.create_endpoint( EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags ) if wait: self.wait_for_endpoint(endpoint_name) return endpoint_name def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True): """Update an Amazon SageMaker ``Endpoint`` , Raise an error endpoint_name does not exist. Args: endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update. endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy. wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True). Returns: str: Name of the Amazon SageMaker ``Endpoint`` being updated. Raises: ValueError: if the endpoint does not already exist """ if not _deployment_entity_exists( lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name) ): raise ValueError( "Endpoint with name '{}' does not exist; please use an " "existing endpoint name".format(endpoint_name) ) self.sagemaker_client.update_endpoint( EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name ) if wait: self.wait_for_endpoint(endpoint_name) return endpoint_name def delete_endpoint(self, endpoint_name): """Delete an Amazon SageMaker ``Endpoint``. Args: endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete. """ LOGGER.info("Deleting endpoint with name: %s", endpoint_name) self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name) def delete_endpoint_config(self, endpoint_config_name): """Delete an Amazon SageMaker endpoint configuration. Args: endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete. """ LOGGER.info("Deleting endpoint configuration with name: %s", endpoint_config_name) self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) def delete_model(self, model_name): """Delete an Amazon SageMaker Model. Args: model_name (str): Name of the Amazon SageMaker model to delete. """ LOGGER.info("Deleting model with name: %s", model_name) self.sagemaker_client.delete_model(ModelName=model_name) def list_group_resources(self, group, filters, next_token: str = ""): """To list group resources with given filters Args: group (str): The name or the ARN of the group. filters (list): Filters that needs to be applied to the list operation. """ self.resource_groups_client = self.resource_groups_client or self.boto_session.client( "resource-groups" ) return self.resource_groups_client.list_group_resources( Group=group, Filters=filters, NextToken=next_token ) def delete_resource_group(self, group): """To delete a resource group Args: group (str): The name or the ARN of the resource group to delete. """ self.resource_groups_client = self.resource_groups_client or self.boto_session.client( "resource-groups" ) return self.resource_groups_client.delete_group(Group=group) def get_resource_group_query(self, group): """To get the group query for an AWS Resource Group Args: group (str): The name or the ARN of the resource group to query. 
""" self.resource_groups_client = self.resource_groups_client or self.boto_session.client( "resource-groups" ) return self.resource_groups_client.get_group_query(Group=group) def get_tagging_resources(self, tag_filters, resource_type_filters): """To list the complete resources for a particular resource group tag tag_filters: filters for the tag resource_type_filters: resource filter for the tag """ self.resource_group_tagging_client = ( self.resource_group_tagging_client or self.boto_session.client("resourcegroupstaggingapi") ) resource_list = [] try: resource_tag_response = self.resource_group_tagging_client.get_resources( TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters ) resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] next_token = resource_tag_response.get("PaginationToken") while next_token is not None and next_token != "": resource_tag_response = self.resource_group_tagging_client.get_resources( TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters, NextToken=next_token, ) resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] next_token = resource_tag_response.get("PaginationToken") return resource_list except ClientError as error: raise error def create_group(self, name, resource_query, tags): """To create a AWS Resource Group Args: name (str): The name of the group, which is also the identifier of the group. resource_query (str): The resource query that determines which AWS resources are members of this group tags (dict): The Tags to be attached to the Resource Group """ self.resource_groups_client = self.resource_groups_client or self.boto_session.client( "resource-groups" ) return self.resource_groups_client.create_group( Name=name, ResourceQuery=resource_query, Tags=tags ) def list_tags(self, resource_arn, max_results=50): """List the tags given an Amazon Resource Name. Args: resource_arn (str): The Amazon Resource Name (ARN) for which to get the tags list. max_results (int): The maximum number of results to include in a single page. This method takes care of that abstraction and returns a full list. """ tags_list = [] try: list_tags_response = self.sagemaker_client.list_tags( ResourceArn=resource_arn, MaxResults=max_results ) tags_list = tags_list + list_tags_response["Tags"] next_token = list_tags_response.get("nextToken") while next_token is not None: list_tags_response = self.sagemaker_client.list_tags( ResourceArn=resource_arn, MaxResults=max_results, NextToken=next_token ) tags_list = tags_list + list_tags_response["Tags"] next_token = list_tags_response.get("nextToken") non_aws_tags = [] for tag in tags_list: if "aws:" not in tag["Key"]: non_aws_tags.append(tag) return non_aws_tags except ClientError as error: print("Error retrieving tags. resource_arn: {}".format(resource_arn)) raise error def wait_for_job(self, job, poll=5): """Wait for an Amazon SageMaker training job to complete. Args: job (str): Name of the training job to wait for. poll (int): Polling interval in seconds (default: 5). Returns: (dict): Return value from the ``DescribeTrainingJob`` API. Raises: exceptions.CapacityError: If the training job fails with CapacityError. exceptions.UnexpectedStatusException: If the training job fails. """ desc = _wait_until_training_done( lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll ) _check_job_status(job, desc, "TrainingJobStatus") return desc def wait_for_processing_job(self, job, poll=5): """Wait for an Amazon SageMaker Processing job to complete. 
        Args:
            job (str): Name of the processing job to wait for.
            poll (int): Polling interval in seconds (default: 5).

        Returns:
            (dict): Return value from the ``DescribeProcessingJob`` API.

        Raises:
            exceptions.CapacityError: If the processing job fails with CapacityError.
            exceptions.UnexpectedStatusException: If the processing job fails.
        """
        desc = _wait_until(lambda: _processing_job_status(self.sagemaker_client, job), poll)
        _check_job_status(job, desc, "ProcessingJobStatus")
        return desc

    def wait_for_compilation_job(self, job, poll=5):
        """Wait for an Amazon SageMaker Neo compilation job to complete.

        Args:
            job (str): Name of the compilation job to wait for.
            poll (int): Polling interval in seconds (default: 5).

        Returns:
            (dict): Return value from the ``DescribeCompilationJob`` API.

        Raises:
            exceptions.CapacityError: If the compilation job fails with CapacityError.
            exceptions.UnexpectedStatusException: If the compilation job fails.
        """
        desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
        _check_job_status(job, desc, "CompilationJobStatus")
        return desc

    def wait_for_edge_packaging_job(self, job, poll=5):
        """Wait for an Amazon SageMaker Edge packaging job to complete.

        Args:
            job (str): Name of the edge packaging job to wait for.
            poll (int): Polling interval in seconds (default: 5).

        Returns:
            (dict): Return value from the ``DescribeEdgePackagingJob`` API.

        Raises:
            exceptions.CapacityError: If the edge packaging job fails with CapacityError.
            exceptions.UnexpectedStatusException: If the edge packaging job fails.
        """
        desc = _wait_until(lambda: _edge_packaging_job_status(self.sagemaker_client, job), poll)
        _check_job_status(job, desc, "EdgePackagingJobStatus")
        return desc

    def wait_for_tuning_job(self, job, poll=5):
        """Wait for an Amazon SageMaker hyperparameter tuning job to complete.

        Args:
            job (str): Name of the tuning job to wait for.
            poll (int): Polling interval in seconds (default: 5).

        Returns:
            (dict): Return value from the ``DescribeHyperParameterTuningJob`` API.

        Raises:
            exceptions.CapacityError: If the hyperparameter tuning job fails with
                CapacityError.
            exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
        """
        desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
        _check_job_status(job, desc, "HyperParameterTuningJobStatus")
        return desc

    def describe_transform_job(self, job_name):
        """Calls the DescribeTransformJob API for the given job name and returns the response.

        Args:
            job_name (str): The name of the transform job to describe.

        Returns:
            dict: A dictionary response with the transform job description.
        """
        return self.sagemaker_client.describe_transform_job(TransformJobName=job_name)

    def wait_for_transform_job(self, job, poll=5):
        """Wait for an Amazon SageMaker transform job to complete.

        Args:
            job (str): Name of the transform job to wait for.
            poll (int): Polling interval in seconds (default: 5).

        Returns:
            (dict): Return value from the ``DescribeTransformJob`` API.

        Raises:
            exceptions.CapacityError: If the transform job fails with CapacityError.
            exceptions.UnexpectedStatusException: If the transform job fails.
        """
        desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
        _check_job_status(job, desc, "TransformJobStatus")
        return desc

    def stop_transform_job(self, name):
        """Stop the Amazon SageMaker batch transform job with the specified name.

        Args:
            name (str): Name of the Amazon SageMaker batch transform job.

        Raises:
            ClientError: If an error occurs while trying to stop the batch transform job.
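        Example:
            A minimal usage sketch; the job name is a placeholder:

            >>> sagemaker_session.stop_transform_job("my-transform-job")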
""" try: LOGGER.info("Stopping transform job: %s", name) self.sagemaker_client.stop_transform_job(TransformJobName=name) except ClientError as e: error_code = e.response["Error"]["Code"] # allow to pass if the job already stopped if error_code == "ValidationException": LOGGER.info("Transform job: %s is already stopped or not running.", name) else: LOGGER.error("Error occurred while attempting to stop transform job: %s.", name) raise def wait_for_endpoint(self, endpoint, poll=30): """Wait for an Amazon SageMaker endpoint deployment to complete. Args: endpoint (str): Name of the ``Endpoint`` to wait for. poll (int): Polling interval in seconds (default: 30). Raises: exceptions.CapacityError: If the endpoint creation job fails with CapacityError. exceptions.UnexpectedStatusException: If the endpoint creation job fails. Returns: dict: Return value from the ``DescribeEndpoint`` API. """ desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll) status = desc["EndpointStatus"] if status != "InService": reason = desc.get("FailureReason", None) message = "Error hosting endpoint {endpoint}: {status}. Reason: {reason}.".format( endpoint=endpoint, status=status, reason=reason ) if "CapacityError" in str(reason): raise exceptions.CapacityError( message=message, allowed_statuses=["InService"], actual_status=status, ) raise exceptions.UnexpectedStatusException( message=message, allowed_statuses=["InService"], actual_status=status, ) return desc def endpoint_from_job( self, job_name, initial_instance_count, instance_type, image_uri=None, name=None, role=None, wait=True, model_environment_vars=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, accelerator_type=None, data_capture_config=None, ): """Create an ``Endpoint`` using the results of a successful training job. Specify the job name, Docker image containing the inference code, and hardware configuration to deploy the model. Internally the API, creates an Amazon SageMaker model (that describes the model artifacts and the Docker image containing inference code), endpoint configuration (describing the hardware to deploy for hosting the model), and creates an ``Endpoint`` (launches the EC2 instances and deploys the model on them). In response, the API returns the endpoint name to which you can send requests for inferences. Args: job_name (str): Name of the training job to deploy the results of. initial_instance_count (int): Minimum number of EC2 instances to launch. The actual number of active instances for an endpoint at any given time varies due to autoscaling. instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction, for example, 'ml.c4.xlarge'. image_uri (str): The Docker image which defines the inference code to be used as the entry point for accepting prediction requests. If not specified, uses the image used for the training job. name (str): Name of the ``Endpoint`` to create. If not specified, uses the training job name. role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. You must grant sufficient permissions to this role. wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True). model_environment_vars (dict[str, str]): Environment variables to set on the model container (default: None). vpc_config_override (dict[str, list[str]]): Overrides VpcConfig set on the model. Default: use VpcConfig from training job. 
* 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. accelerator_type (str): Type of Elastic Inference accelerator to attach to the instance. For example, 'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None. Returns: str: Name of the ``Endpoint`` that is created. """ job_desc = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) model_s3_location = self._gen_s3_model_data_source(job_desc) image_uri = image_uri or job_desc["AlgorithmSpecification"]["TrainingImage"] role = role or job_desc["RoleArn"] name = name or job_name vpc_config_override = _vpc_config_from_training_job(job_desc, vpc_config_override) return self.endpoint_from_model_data( model_s3_location=model_s3_location, image_uri=image_uri, initial_instance_count=initial_instance_count, instance_type=instance_type, name=name, role=role, wait=wait, model_environment_vars=model_environment_vars, model_vpc_config=vpc_config_override, accelerator_type=accelerator_type, data_capture_config=data_capture_config, ) def _gen_s3_model_data_source(self, training_job_spec): """Generates ``ModelDataSource`` value from given DescribeTrainingJob API response. Args: training_job_spec (dict): SageMaker DescribeTrainingJob API response. Returns: dict: A ``ModelDataSource`` value. """ model_data_s3_uri = training_job_spec["ModelArtifacts"]["S3ModelArtifacts"] compression_type = training_job_spec.get("OutputDataConfig", {}).get( "CompressionType", "GZIP" ) # See https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_OutputDataConfig.html # and https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3ModelDataSource.html if compression_type in {"NONE", "GZIP"}: model_compression_type = compression_type.title() else: raise ValueError( f'Unrecognized training job output data compression type "{compression_type}"' ) s3_model_data_type = "S3Object" if model_compression_type == "Gzip" else "S3Prefix" # if model data is in S3Prefix type and has no trailing forward slash in its URI, # append one so that it meets SageMaker Hosting's mandate for deploying uncompressed model. if s3_model_data_type == "S3Prefix" and not model_data_s3_uri.endswith("/"): model_data_s3_uri += "/" return { "S3DataSource": { "S3Uri": model_data_s3_uri, "S3DataType": s3_model_data_type, "CompressionType": model_compression_type, } } def endpoint_from_model_data( self, model_s3_location, image_uri, initial_instance_count, instance_type, name=None, role=None, wait=True, model_environment_vars=None, model_vpc_config=None, accelerator_type=None, data_capture_config=None, tags=None, ): """Create and deploy to an ``Endpoint`` using existing model data stored in S3. Args: model_s3_location (str or dict): S3 location of the model artifacts to use for the endpoint. image_uri (str): The Docker image URI which defines the runtime code to be used as the entry point for accepting prediction requests. initial_instance_count (int): Minimum number of EC2 instances to launch. The actual number of active instances for an endpoint at any given time varies due to autoscaling. instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction, e.g. 'ml.c4.xlarge'. name (str): Name of the ``Endpoint`` to create. 
                If not specified, uses a name generated by combining the image name with a
                timestamp.
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker
                training jobs and APIs that create Amazon SageMaker endpoints use this role to
                access training data and model artifacts.
                You must grant sufficient permissions to this role.
            wait (bool): Whether to wait for the endpoint deployment to complete before
                returning (default: True).
            model_environment_vars (dict[str, str]): Environment variables to set on the model
                container (default: None).
            model_vpc_config (dict[str, list[str]]): The VpcConfig set on the model
                (default: None)

                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.

            accelerator_type (str): Type of Elastic Inference accelerator to attach to the
                instance. For example, 'ml.eia1.medium'. For more information:
                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
            data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
                configuration related to Endpoint data capture for use with Amazon SageMaker
                Model Monitoring. Default: None.
            tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint
                (default: None).

        Returns:
            str: Name of the ``Endpoint`` that is created.
        """
        model_environment_vars = model_environment_vars or {}
        name = name or name_from_image(image_uri)
        model_vpc_config = vpc_utils.sanitize(model_vpc_config)
        endpoint_config_tags = _append_project_tags(tags)
        endpoint_tags = _append_project_tags(tags)
        endpoint_config_tags = self._append_sagemaker_config_tags(
            endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
        )
        primary_container = container_def(
            image_uri=image_uri,
            model_data_url=model_s3_location,
            env=model_environment_vars,
        )

        self.create_model(
            name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
        )

        data_capture_config_dict = None
        if data_capture_config is not None:
            data_capture_config_dict = data_capture_config._to_request_dict()

        _create_resource(
            lambda: self.create_endpoint_config(
                name=name,
                model_name=name,
                initial_instance_count=initial_instance_count,
                instance_type=instance_type,
                accelerator_type=accelerator_type,
                data_capture_config_dict=data_capture_config_dict,
                tags=endpoint_config_tags,
            )
        )

        # _create_resource returns False when the endpoint already exists; raising here
        # keeps the change backwards compatible with the previous behavior.
        response = _create_resource(
            lambda: self.create_endpoint(
                endpoint_name=name, config_name=name, tags=endpoint_tags, wait=wait
            )
        )
        if not response:
            raise ValueError(
                'Endpoint with name "{}" already exists; please pick a different name.'.format(
                    name
                )
            )

        return name

    def endpoint_from_production_variants(
        self,
        name,
        production_variants,
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config_dict=None,
        async_inference_config_dict=None,
        explainer_config_dict=None,
    ):
        """Create a SageMaker ``Endpoint`` from a list of production variants.

        Args:
            name (str): The name of the ``Endpoint`` to create.
            production_variants (list[dict[str, str]]): The list of production variants to
                deploy.
            tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint
                (default: None).
            kms_key (str): The KMS key that is used to encrypt the data on the storage volume
                attached to the instance hosting the endpoint.
            wait (bool): Whether to wait for the endpoint deployment to complete before
                returning (default: True).
            data_capture_config_dict (dict): Specifies configuration related to Endpoint data
                capture for use with Amazon SageMaker Model Monitoring. Default: None.
            async_inference_config_dict (dict): Specifies configuration related to an async
                endpoint. Use this configuration when creating an async endpoint and making
                asynchronous inference requests (default: None).
            explainer_config_dict (dict): Specifies configuration related to the explainer.
                Use this configuration to enable online explainability (default: None).

        Returns:
            str: The name of the created ``Endpoint``.
        """
        supports_kms = any(
            [
                instance_supports_kms(production_variant["InstanceType"])
                for production_variant in production_variants
                if "InstanceType" in production_variant
            ]
        )
        update_list_of_dicts_with_values_from_config(
            production_variants,
            ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
            required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
            sagemaker_session=self,
        )
        config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
        kms_key = (
            resolve_value_from_config(
                kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
            )
            if supports_kms
            else kms_key
        )
        endpoint_config_tags = _append_project_tags(tags)
        endpoint_tags = _append_project_tags(tags)
        endpoint_config_tags = self._append_sagemaker_config_tags(
            endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
        )
        if endpoint_config_tags:
            config_options["Tags"] = endpoint_config_tags
        if kms_key:
            config_options["KmsKeyId"] = kms_key
        if data_capture_config_dict is not None:
            inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config(
                data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self
            )
            config_options["DataCaptureConfig"] = inferred_data_capture_config_dict
        if async_inference_config_dict is not None:
            inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
                async_inference_config_dict,
                ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
                sagemaker_session=self,
            )
            config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict
        if explainer_config_dict is not None:
            config_options["ExplainerConfig"] = explainer_config_dict

        LOGGER.info("Creating endpoint-config with name %s", name)
        self.sagemaker_client.create_endpoint_config(**config_options)

        return self.create_endpoint(
            endpoint_name=name, config_name=name, tags=endpoint_tags, wait=wait
        )

    def expand_role(self, role):
        """Expand an IAM role name into an ARN.

        If the role is already in the form of an ARN, then the role is simply returned.
        Otherwise, the full ARN is retrieved and returned.

        Args:
            role (str): An AWS IAM role (either name or full ARN).

        Returns:
            str: The corresponding AWS IAM role ARN.
        """
        if "/" in role:
            return role
        return self.boto_session.resource("iam").Role(role).arn

    def get_caller_identity_arn(self):
        """Returns the ARN of the user or role whose credentials are used to call the API.
        Returns:
            str: The ARN of the user or role.
        """
        if os.path.exists(NOTEBOOK_METADATA_FILE):
            with open(NOTEBOOK_METADATA_FILE, "rb") as f:
                metadata = json.loads(f.read())
                instance_name = metadata["ResourceName"]
                domain_id = metadata.get("DomainId")
                user_profile_name = metadata.get("UserProfileName")
                space_name = metadata.get("SpaceName")
            try:
                if domain_id is None:
                    instance_desc = self.sagemaker_client.describe_notebook_instance(
                        NotebookInstanceName=instance_name
                    )
                    return instance_desc["RoleArn"]

                # In Space app, find execution role from DefaultSpaceSettings on domain level
                if space_name is not None:
                    domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id)
                    return domain_desc["DefaultSpaceSettings"]["ExecutionRole"]

                user_profile_desc = self.sagemaker_client.describe_user_profile(
                    DomainId=domain_id, UserProfileName=user_profile_name
                )

                # First, try to find role in userSettings
                if user_profile_desc.get("UserSettings", {}).get("ExecutionRole"):
                    return user_profile_desc["UserSettings"]["ExecutionRole"]

                # If not found, fallback to the domain
                domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id)
                return domain_desc["DefaultUserSettings"]["ExecutionRole"]
            except ClientError:
                LOGGER.debug(
                    "Couldn't call 'describe_notebook_instance' to get the Role "
                    "ARN of the instance %s.",
                    instance_name,
                )

        assumed_role = self.boto_session.client(
            "sts",
            region_name=self.boto_region_name,
            endpoint_url=sts_regional_endpoint(self.boto_region_name),
        ).get_caller_identity()["Arn"]

        role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)

        # Call IAM to get the role's path
        role_name = role[role.rfind("/") + 1 :]
        try:
            role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
        except ClientError:
            LOGGER.warning(
                "Couldn't call 'get_role' to get Role ARN from role name %s to get Role path.",
                role_name,
            )

            # This conditional has been present since the inception of SageMaker
            # Guessing this conditional's purpose was to handle lack of IAM permissions
            # https://github.com/aws/sagemaker-python-sdk/issues/2089#issuecomment-791802713
            if "AmazonSageMaker-ExecutionRole" in assumed_role:
                LOGGER.warning(
                    "Assuming role was created in SageMaker AWS console, "
                    "as the name contains `AmazonSageMaker-ExecutionRole`. "
                    "Defaulting to Role ARN with service-role in path. "
                    "If this Role ARN is incorrect, please add "
                    "IAM read permissions to your role or supply the "
                    "Role Arn directly."
                )
                role = re.sub(
                    r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
                    r"\1iam::\2:role/service-role/\3",
                    assumed_role,
                )

        return role

    def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=None):
        """Display logs for a given training job, optionally tailing them until the job is complete.

        If the output is a tty or a Jupyter cell, it will be color-coded based on which
        instance the log entry is from.

        Args:
            job_name (str): Name of the training job to display the logs for.
            wait (bool): Whether to keep looking for new log entries until the job completes
                (default: False).
            poll (int): The interval in seconds between polling for new log entries and job
                completion (default: 10).
            log_type ([str]): A list of strings specifying which logs to print. Acceptable
                strings are "All", "None", "Training", or "Rules". To maintain backwards
                compatibility, boolean values are also accepted and converted to strings.
            timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
                default.

        Raises:
            exceptions.CapacityError: If the training job fails with CapacityError.
            exceptions.UnexpectedStatusException: If waiting and the training job fails.
        """
        _logs_for_job(self.boto_session, job_name, wait, poll, log_type, timeout)

    def logs_for_processing_job(self, job_name, wait=False, poll=10):
        """Display logs for a given processing job, optionally tailing them until the job is complete.

        Args:
            job_name (str): Name of the processing job to display the logs for.
            wait (bool): Whether to keep looking for new log entries until the job completes
                (default: False).
            poll (int): The interval in seconds between polling for new log entries and job
                completion (default: 10).

        Raises:
            ValueError: If the processing job fails.
        """
        description = _wait_until(lambda: self.describe_processing_job(job_name), poll)

        instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
            self.boto_session, description, job="Processing"
        )

        state = _get_initial_job_state(description, "ProcessingJobStatus", wait)

        # The loop below implements a state machine that alternates between checking the job status
        # and reading whatever is available in the logs at this point. Note, that if we were
        # called with wait == False, we never check the job status.
        #
        # If wait == TRUE and job is not completed, the initial state is TAILING
        # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
        # complete).
        #
        # The state table:
        #
        # STATE               ACTIONS                        CONDITION           NEW STATE
        # ----------------    ----------------               -----------------   ----------------
        # TAILING             Read logs, Pause, Get status   Job complete        JOB_COMPLETE
        #                                                    Else                TAILING
        # JOB_COMPLETE        Read logs, Pause               Any                 COMPLETE
        # COMPLETE            Read logs, Exit                                    N/A
        #
        # Notes:
        # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
        #   Cloudwatch after the job was marked complete.
        last_describe_job_call = time.time()
        while True:
            _flush_log_streams(
                stream_names,
                instance_count,
                client,
                log_group,
                job_name,
                positions,
                dot,
                color_wrap,
            )
            if state == LogState.COMPLETE:
                break

            time.sleep(poll)

            if state == LogState.JOB_COMPLETE:
                state = LogState.COMPLETE
            elif time.time() - last_describe_job_call >= 30:
                description = self.sagemaker_client.describe_processing_job(
                    ProcessingJobName=job_name
                )
                last_describe_job_call = time.time()

                status = description["ProcessingJobStatus"]

                if status in ("Completed", "Failed", "Stopped"):
                    print()
                    state = LogState.JOB_COMPLETE

        if wait:
            _check_job_status(job_name, description, "ProcessingJobStatus")
            if dot:
                print()

    def logs_for_transform_job(self, job_name, wait=False, poll=10):
        """Display logs for a given transform job, optionally tailing them until the job is complete.

        If the output is a tty or a Jupyter cell, it will be color-coded based on which
        instance the log entry is from.

        Args:
            job_name (str): Name of the transform job to display the logs for.
            wait (bool): Whether to keep looking for new log entries until the job completes
                (default: False).
            poll (int): The interval in seconds between polling for new log entries and job
                completion (default: 10).

        Raises:
            ValueError: If the transform job fails.
        """
        description = _wait_until(lambda: self.describe_transform_job(job_name), poll)

        instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
            self.boto_session, description, job="Transform"
        )

        state = _get_initial_job_state(description, "TransformJobStatus", wait)

        # The loop below implements a state machine that alternates between checking the job status
        # and reading whatever is available in the logs at this point.
Note, that if we were
        # called with wait == False, we never check the job status.
        #
        # If wait == TRUE and job is not completed, the initial state is TAILING
        # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
        # complete).
        #
        # The state table:
        #
        # STATE               ACTIONS                        CONDITION           NEW STATE
        # ----------------    ----------------               -----------------   ----------------
        # TAILING             Read logs, Pause, Get status   Job complete        JOB_COMPLETE
        #                                                    Else                TAILING
        # JOB_COMPLETE        Read logs, Pause               Any                 COMPLETE
        # COMPLETE            Read logs, Exit                                    N/A
        #
        # Notes:
        # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
        #   Cloudwatch after the job was marked complete.
        last_describe_job_call = time.time()
        while True:
            _flush_log_streams(
                stream_names,
                instance_count,
                client,
                log_group,
                job_name,
                positions,
                dot,
                color_wrap,
            )
            if state == LogState.COMPLETE:
                break

            time.sleep(poll)

            if state == LogState.JOB_COMPLETE:
                state = LogState.COMPLETE
            elif time.time() - last_describe_job_call >= 30:
                description = self.sagemaker_client.describe_transform_job(
                    TransformJobName=job_name
                )
                last_describe_job_call = time.time()

                status = description["TransformJobStatus"]

                if status in ("Completed", "Failed", "Stopped"):
                    print()
                    state = LogState.JOB_COMPLETE

        if wait:
            _check_job_status(job_name, description, "TransformJobStatus")
            if dot:
                print()

    def delete_feature_group(self, feature_group_name: str):
        """Deletes a FeatureGroup in the FeatureStore service.

        Args:
            feature_group_name (str): name of the feature group to be deleted.
        """
        self.sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name)

    def create_feature_group(
        self,
        feature_group_name: str,
        record_identifier_name: str,
        event_time_feature_name: str,
        feature_definitions: Sequence[Dict[str, str]],
        role_arn: str = None,
        online_store_config: Dict[str, str] = None,
        offline_store_config: Dict[str, str] = None,
        description: str = None,
        tags: List[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Creates a FeatureGroup in the FeatureStore service.

        Args:
            feature_group_name (str): name of the FeatureGroup.
            record_identifier_name (str): name of the record identifier feature.
            event_time_feature_name (str): name of the event time feature.
            feature_definitions (Sequence[Dict[str, str]]): list of feature definitions.
            role_arn (str): ARN of the role that will be used to execute the API.
            online_store_config (Dict[str, str]): dict containing the configuration of the
                feature online store.
            offline_store_config (Dict[str, str]): dict containing the configuration of the
                feature offline store.
            description (str): description of the FeatureGroup.
            tags (List[Dict[str, str]]): list of tags for labeling a FeatureGroup.

        Returns:
            Response dict from service.
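
        Example:
            A minimal usage sketch; the feature names, role ARN, and bucket below are
            hypothetical placeholders::

                sagemaker_session.create_feature_group(
                    feature_group_name="my-feature-group",  # hypothetical name
                    record_identifier_name="record_id",
                    event_time_feature_name="event_time",
                    feature_definitions=[
                        {"FeatureName": "record_id", "FeatureType": "String"},
                        {"FeatureName": "event_time", "FeatureType": "String"},
                    ],
                    role_arn="arn:aws:iam::123456789012:role/MyFeatureStoreRole",
                    offline_store_config={
                        "S3StorageConfig": {"S3Uri": "s3://my-bucket/feature-store/"}
                    },
                )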
""" tags = _append_project_tags(tags) tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) ) role_arn = resolve_value_from_config( role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self ) inferred_online_store_from_config = update_nested_dictionary_with_values_from_config( online_store_config, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, sagemaker_session=self ) if inferred_online_store_from_config is not None: # OnlineStore should be handled differently because if you set KmsKeyId, then you # need to set EnableOnlineStore key as well inferred_online_store_from_config["EnableOnlineStore"] = True inferred_offline_store_from_config = update_nested_dictionary_with_values_from_config( offline_store_config, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, sagemaker_session=self ) kwargs = dict( FeatureGroupName=feature_group_name, RecordIdentifierFeatureName=record_identifier_name, EventTimeFeatureName=event_time_feature_name, FeatureDefinitions=feature_definitions, RoleArn=role_arn, ) update_args( kwargs, OnlineStoreConfig=inferred_online_store_from_config, OfflineStoreConfig=inferred_offline_store_from_config, Description=description, Tags=tags, ) return self.sagemaker_client.create_feature_group(**kwargs) def describe_feature_group( self, feature_group_name: str, next_token: str = None, ) -> Dict[str, Any]: """Describe a FeatureGroup by name in FeatureStore service. Args: feature_group_name (str): name of the FeatureGroup to describe. next_token (str): next_token to get next page of features. Returns: Response dict from service. """ kwargs = dict(FeatureGroupName=feature_group_name) update_args(kwargs, NextToken=next_token) return self.sagemaker_client.describe_feature_group(**kwargs) def update_feature_group( self, feature_group_name: str, feature_additions: Sequence[Dict[str, str]] = None, online_store_config: Dict[str, any] = None, ) -> Dict[str, Any]: """Update a FeatureGroup either adding new features from the given feature definitions or updating online store config Args: feature_group_name (str): name of the FeatureGroup to update. feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated. Returns: Response dict from service. """ if feature_additions is None: return self.sagemaker_client.update_feature_group( FeatureGroupName=feature_group_name, OnlineStoreConfig=online_store_config, ) return self.sagemaker_client.update_feature_group( FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions ) def list_feature_groups( self, name_contains, feature_group_status_equals, offline_store_status_equals, creation_time_after, creation_time_before, sort_order, sort_by, max_results, next_token, ) -> Dict[str, Any]: """List all FeatureGroups satisfying given filters. Args: name_contains (str): A string that partially matches one or more FeatureGroups' names. Filters FeatureGroups by name. feature_group_status_equals (str): A FeatureGroup status. Filters FeatureGroups by FeatureGroup status. offline_store_status_equals (str): An OfflineStore status. Filters FeatureGroups by OfflineStore status. creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups created after a specific date and time. creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups created before a specific date and time. sort_order (str): The order in which FeatureGroups are listed. sort_by (str): The value on which the FeatureGroup list is sorted. 
            max_results (int): The maximum number of results returned by ListFeatureGroups.
            next_token (str): A token to resume pagination of ListFeatureGroups results.

        Returns:
            Response dict from service.
        """
        list_feature_groups_args = {}

        def check_object(key, value):
            if value is not None:
                list_feature_groups_args[key] = value

        check_object("NameContains", name_contains)
        check_object("FeatureGroupStatusEquals", feature_group_status_equals)
        check_object("OfflineStoreStatusEquals", offline_store_status_equals)
        check_object("CreationTimeAfter", creation_time_after)
        check_object("CreationTimeBefore", creation_time_before)
        check_object("SortOrder", sort_order)
        check_object("SortBy", sort_by)
        check_object("MaxResults", max_results)
        check_object("NextToken", next_token)

        return self.sagemaker_client.list_feature_groups(**list_feature_groups_args)

    def update_feature_metadata(
        self,
        feature_group_name: str,
        feature_name: str,
        description: str = None,
        parameter_additions: Sequence[Dict[str, str]] = None,
        parameter_removals: Sequence[str] = None,
    ) -> Dict[str, Any]:
        """Update a feature's metadata, adding or removing parameters as needed.

        Args:
            feature_group_name (str): name of the FeatureGroup to update.
            feature_name (str): name of the feature to update.
            description (str): description of the feature to update.
            parameter_additions (Sequence[Dict[str, str]]): list of feature parameters to be
                added.
            parameter_removals (Sequence[str]): list of feature parameter keys to be removed.

        Returns:
            Response dict from service.
        """
        request = {
            "FeatureGroupName": feature_group_name,
            "FeatureName": feature_name,
        }

        if description is not None:
            request["Description"] = description
        if parameter_additions is not None:
            request["ParameterAdditions"] = parameter_additions
        if parameter_removals is not None:
            request["ParameterRemovals"] = parameter_removals

        return self.sagemaker_client.update_feature_metadata(**request)

    def describe_feature_metadata(
        self, feature_group_name: str, feature_name: str
    ) -> Dict[str, Any]:
        """Describe feature metadata by feature name in FeatureStore service.

        Args:
            feature_group_name (str): name of the FeatureGroup.
            feature_name (str): name of the feature.

        Returns:
            Response dict from service.
        """
        return self.sagemaker_client.describe_feature_metadata(
            FeatureGroupName=feature_group_name, FeatureName=feature_name
        )

    def search(
        self,
        resource: str,
        search_expression: Dict[str, Any] = None,
        sort_by: str = None,
        sort_order: str = None,
        next_token: str = None,
        max_results: int = None,
    ) -> Dict[str, Any]:
        """Search for SageMaker resources satisfying given filters.

        Args:
            resource (str): The name of the Amazon SageMaker resource to search for.
            search_expression (Dict[str, Any]): A Boolean conditional statement. Resources must
                satisfy this condition to be included in search results.
            sort_by (str): The name of the resource property used to sort the ``SearchResults``.
                The default is ``LastModifiedTime``.
            sort_order (str): How ``SearchResults`` are ordered. Valid values are ``Ascending``
                or ``Descending``. The default is ``Descending``.
            next_token (str): If more than ``MaxResults`` resources match the specified
                ``SearchExpression``, the response includes a ``NextToken``. The ``NextToken``
                can be passed to the next ``SearchRequest`` to continue retrieving results.
            max_results (int): The maximum number of results to return.

        Returns:
            Response dict from service.
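
        Example:
            A minimal usage sketch; the filter values are illustrative only::

                response = sagemaker_session.search(
                    resource="TrainingJob",
                    search_expression={
                        "Filters": [
                            {
                                "Name": "TrainingJobStatus",
                                "Operator": "Equals",
                                "Value": "Completed",
                            }
                        ]
                    },
                    sort_by="CreationTime",
                    sort_order="Descending",
                    max_results=10,
                )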
""" search_args = {"Resource": resource} if search_expression: search_args["SearchExpression"] = search_expression if sort_by: search_args["SortBy"] = sort_by if sort_order: search_args["SortOrder"] = sort_order if next_token: search_args["NextToken"] = next_token if max_results: search_args["MaxResults"] = max_results return self.sagemaker_client.search(**search_args) def put_record( self, feature_group_name: str, record: Sequence[Dict[str, str]], ttl_duration: Dict[str, str] = None, ): """Puts a single record in the FeatureGroup. Args: feature_group_name (str): name of the FeatureGroup. record (Sequence[Dict[str, str]]): list of FeatureValue dicts to be ingested into FeatureStore. """ if ttl_duration: return self.sagemaker_featurestore_runtime_client.put_record( FeatureGroupName=feature_group_name, Record=record, TtlDuration=ttl_duration, ) return self.sagemaker_featurestore_runtime_client.put_record( FeatureGroupName=feature_group_name, Record=record, ) def delete_record( self, feature_group_name: str, record_identifier_value_as_string: str, event_time: str, deletion_mode: str = None, ): """Deletes a single record from the FeatureGroup. Args: feature_group_name (str): name of the FeatureGroup. record_identifier_value_as_string (str): name of the record identifier. event_time (str): a timestamp indicating when the deletion event occurred. deletion_mode: (str): deletion mode for deleting record. """ return self.sagemaker_featurestore_runtime_client.delete_record( FeatureGroupName=feature_group_name, RecordIdentifierValueAsString=record_identifier_value_as_string, EventTime=event_time, DeletionMode=deletion_mode, ) def get_record( self, record_identifier_value_as_string: str, feature_group_name: str, feature_names: Sequence[str], expiration_time_response: str = None, ) -> Dict[str, Sequence[Dict[str, str]]]: """Gets a single record in the FeatureGroup. Args: record_identifier_value_as_string (str): name of the record identifier. feature_group_name (str): name of the FeatureGroup. feature_names (Sequence[str]): list of feature names. expiration_time_response (str): the field of expiration time response to toggle returning of expiresAt. """ get_record_args = { "FeatureGroupName": feature_group_name, "RecordIdentifierValueAsString": record_identifier_value_as_string, } if expiration_time_response: get_record_args["ExpirationTimeResponse"] = expiration_time_response if feature_names: get_record_args["FeatureNames"] = feature_names return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args) def batch_get_record( self, identifiers: Sequence[Dict[str, Any]], expiration_time_response: str = None, ) -> Dict[str, Any]: """Gets a batch of record from FeatureStore. Args: identifiers (Sequence[Dict[str, Any]]): list of identifiers to uniquely identify records in FeatureStore. expiration_time_response (str): the field of expiration time response to toggle returning of expiresAt. Returns: Response dict from service. """ batch_get_record_args = {"Identifiers": identifiers} if expiration_time_response: batch_get_record_args["ExpirationTimeResponse"] = expiration_time_response return self.sagemaker_featurestore_runtime_client.batch_get_record(**batch_get_record_args) def start_query_execution( self, catalog: str, database: str, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None, ) -> Dict[str, str]: """Start Athena query execution. Args: catalog (str): name of the data catalog. database (str): name of the data catalog database. 
query_string (str): SQL expression. output_location (str): S3 location of the output file. kms_key (str): KMS key id will be used to encrypt the result if given. workgroup (str): The name of the workgroup in which the query is being started. If the workgroup is not specified, the default workgroup is used. Returns: Response dict from the service. """ kwargs = dict( QueryString=query_string, QueryExecutionContext=dict(Catalog=catalog, Database=database) ) result_config = dict(OutputLocation=output_location) if kms_key: result_config.update( EncryptionConfiguration=dict(EncryptionOption="SSE_KMS", KmsKey=kms_key) ) kwargs.update(ResultConfiguration=result_config) if workgroup: kwargs.update(WorkGroup=workgroup) athena_client = self.boto_session.client("athena", region_name=self.boto_region_name) return athena_client.start_query_execution(**kwargs) def get_query_execution(self, query_execution_id: str) -> Dict[str, Any]: """Get execution status of the Athena query. Args: query_execution_id (str): execution ID of the Athena query. """ athena_client = self.boto_session.client("athena", region_name=self.boto_region_name) return athena_client.get_query_execution(QueryExecutionId=query_execution_id) def wait_for_athena_query(self, query_execution_id: str, poll: int = 5): """Wait for Athena query to finish. Args: query_execution_id (str): execution ID of the Athena query. poll (int): time interval to poll get_query_execution API. """ query_state = ( self.get_query_execution(query_execution_id=query_execution_id) .get("QueryExecution") .get("Status") .get("State") ) while query_state not in ("SUCCEEDED", "FAILED"): LOGGER.info("Query %s is being executed.", query_execution_id) time.sleep(poll) query_state = ( self.get_query_execution(query_execution_id=query_execution_id) .get("QueryExecution") .get("Status") .get("State") ) if query_state == "SUCCEEDED": LOGGER.info("Query %s successfully executed.", query_execution_id) else: LOGGER.error("Failed to execute query %s.", query_execution_id) def download_athena_query_result( self, bucket: str, prefix: str, query_execution_id: str, filename: str, ): """Download query result file from S3. Args: bucket (str): name of the S3 bucket where the result file is stored. prefix (str): S3 prefix of the result file. query_execution_id (str): execution ID of the Athena query. filename (str): name of the downloaded file. """ if self.s3_client is None: s3 = self.boto_session.client("s3", region_name=self.boto_region_name) else: s3 = self.s3_client s3.download_file(Bucket=bucket, Key=f"{prefix}/{query_execution_id}.csv", Filename=filename) def account_id(self) -> str: """Get the AWS account id of the caller. Returns: AWS account ID. """ region = self.boto_session.region_name sts_client = self.boto_session.client( "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) ) return sts_client.get_caller_identity()["Account"] def _intercept_create_request( self, request: typing.Dict, create, func_name: str = None # pylint: disable=unused-argument ): """This function intercepts the create job request. PipelineSession inherits this Session class and will override this function to intercept the create request. 
        Args:
            request (dict): the create job request
            create (functor): a functor that calls the SageMaker client's create method
            func_name (str): the name of the function that needs intercepting
        """
        return create(request)

    def _create_inference_recommendations_job_request(
        self,
        role: str,
        job_name: str,
        job_description: str,
        framework: str,
        sample_payload_url: str,
        supported_content_types: List[str],
        tags: Dict[str, str],
        model_name: str = None,
        model_package_version_arn: str = None,
        job_duration_in_seconds: int = None,
        job_type: str = "Default",
        framework_version: str = None,
        nearest_model_name: str = None,
        supported_instance_types: List[str] = None,
        endpoint_configurations: List[Dict[str, Any]] = None,
        traffic_pattern: Dict[str, Any] = None,
        stopping_conditions: Dict[str, Any] = None,
        resource_limit: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Get request dictionary for CreateInferenceRecommendationsJob API.

        Args:
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
                jobs and APIs that create Amazon SageMaker endpoints use this role to access
                training data and model artifacts. You must grant sufficient permissions to this
                role.
            job_name (str): The name of the Inference Recommendations Job.
            job_description (str): A description of the Inference Recommendations Job.
            framework (str): The machine learning framework of the Image URI.
            sample_payload_url (str): The S3 path where the sample payload is stored.
            supported_content_types (List[str]): The supported MIME types for the input data.
            model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
            model_package_version_arn (str): The Amazon Resource Name (ARN) of a versioned
                model package.
            job_duration_in_seconds (int): The maximum job duration that a job can run for.
                Will be used for `Advanced` jobs.
            job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
            framework_version (str): The framework version of the Image URI.
            nearest_model_name (str): The name of a pre-trained machine learning model
                benchmarked by Amazon SageMaker Inference Recommender that matches your model.
            supported_instance_types (List[str]): A list of the instance types that are used to
                generate inferences in real-time.
            tags (Dict[str, str]): Tags used to identify where the Inference Recommendations
                call was made from.
            endpoint_configurations (List[Dict[str, Any]]): Specifies the endpoint configurations
                to use for a job. Will be used for `Advanced` jobs.
            traffic_pattern (Dict[str, Any]): Specifies the traffic pattern for the job.
                Will be used for `Advanced` jobs.
            stopping_conditions (Dict[str, Any]): A set of conditions for stopping a
                recommendation job. If any of the conditions are met, the job is automatically
                stopped. Will be used for `Advanced` jobs.
            resource_limit (Dict[str, Any]): Defines the resource limit for the job.
                Will be used for `Advanced` jobs.
Returns: Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API """ containerConfig = { "Domain": "MACHINE_LEARNING", "Task": "OTHER", "Framework": framework, "PayloadConfig": { "SamplePayloadUrl": sample_payload_url, "SupportedContentTypes": supported_content_types, }, } if framework_version: containerConfig["FrameworkVersion"] = framework_version if nearest_model_name: containerConfig["NearestModelName"] = nearest_model_name if supported_instance_types: containerConfig["SupportedInstanceTypes"] = supported_instance_types request = { "JobName": job_name, "JobType": job_type, "RoleArn": role, "InputConfig": { "ContainerConfig": containerConfig, }, "Tags": tags, } request.get("InputConfig").update( {"ModelPackageVersionArn": model_package_version_arn} if model_package_version_arn else {"ModelName": model_name} ) if job_description: request["JobDescription"] = job_description if job_duration_in_seconds: request["InputConfig"]["JobDurationInSeconds"] = job_duration_in_seconds if job_type == "Advanced": if stopping_conditions: request["StoppingConditions"] = stopping_conditions if resource_limit: request["InputConfig"]["ResourceLimit"] = resource_limit if traffic_pattern: request["InputConfig"]["TrafficPattern"] = traffic_pattern if endpoint_configurations: request["InputConfig"]["EndpointConfigurations"] = endpoint_configurations return request def create_inference_recommendations_job( self, role: str, sample_payload_url: str, supported_content_types: List[str], job_name: str = None, job_type: str = "Default", model_name: str = None, model_package_version_arn: str = None, job_duration_in_seconds: int = None, nearest_model_name: str = None, supported_instance_types: List[str] = None, framework: str = None, framework_version: str = None, endpoint_configurations: List[Dict[str, any]] = None, traffic_pattern: Dict[str, any] = None, stopping_conditions: Dict[str, any] = None, resource_limit: Dict[str, any] = None, ): """Creates an Inference Recommendations Job Args: role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. You must grant sufficient permissions to this role. sample_payload_url (str): The S3 path where the sample payload is stored. supported_content_types (List[str]): The supported MIME types for the input data. model_name (str): Name of the Amazon SageMaker ``Model`` to be used. model_package_version_arn (str): The Amazon Resource Name (ARN) of a versioned model package. job_name (str): The name of the job being run. job_type (str): The type of job being run. Must either be `Default` or `Advanced`. job_duration_in_seconds (int): The maximum job duration that a job can run for. Will be used for `Advanced` jobs. nearest_model_name (str): The name of a pre-trained machine learning model benchmarked by Amazon SageMaker Inference Recommender that matches your model. supported_instance_types (List[str]): A list of the instance types that are used to generate inferences in real-time. framework (str): The machine learning framework of the Image URI. framework_version (str): The framework version of the Image URI. endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations to use for a job. Will be used for `Advanced` jobs. traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job. Will be used for `Advanced` jobs. 
            stopping_conditions (Dict[str, Any]): A set of conditions for stopping a
                recommendation job. If any of the conditions are met, the job is automatically
                stopped. Will be used for `Advanced` jobs.
            resource_limit (Dict[str, Any]): Defines the resource limit for the job.
                Will be used for `Advanced` jobs.

        Returns:
            str: The name of the job created. In the form of `SMPYTHONSDK-<uuid>`.
        """
        if model_name is None and model_package_version_arn is None:
            raise ValueError("Please provide either model_name or model_package_version_arn.")
        if model_name is not None and model_package_version_arn is not None:
            raise ValueError(
                "Please provide either model_name or model_package_version_arn, not both."
            )

        if not job_name:
            unique_tail = uuid.uuid4()
            job_name = "SMPYTHONSDK-" + str(unique_tail)
        job_description = "#python-sdk-create"

        tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]

        create_inference_recommendations_job_request = (
            self._create_inference_recommendations_job_request(
                role=role,
                model_name=model_name,
                model_package_version_arn=model_package_version_arn,
                job_name=job_name,
                job_type=job_type,
                job_duration_in_seconds=job_duration_in_seconds,
                job_description=job_description,
                framework=framework,
                framework_version=framework_version,
                nearest_model_name=nearest_model_name,
                sample_payload_url=sample_payload_url,
                supported_content_types=supported_content_types,
                supported_instance_types=supported_instance_types,
                endpoint_configurations=endpoint_configurations,
                traffic_pattern=traffic_pattern,
                stopping_conditions=stopping_conditions,
                resource_limit=resource_limit,
                tags=tags,
            )
        )

        def submit(request):
            LOGGER.info("Creating Inference Recommendations job with name: %s", job_name)
            LOGGER.debug("process request: %s", json.dumps(request, indent=4))
            self.sagemaker_client.create_inference_recommendations_job(**request)

        self._intercept_create_request(
            create_inference_recommendations_job_request,
            submit,
            self.create_inference_recommendations_job.__name__,
        )
        return job_name

    def wait_for_inference_recommendations_job(
        self, job_name: str, poll: int = 120, log_level: str = "Verbose"
    ) -> Dict[str, Any]:
        """Wait for an Amazon SageMaker Inference Recommender job to complete.

        Args:
            job_name (str): Name of the Inference Recommender job to wait for.
            poll (int): Polling interval in seconds (default: 120).
            log_level (str): The level of verbosity for the logs. Can be "Quiet" or "Verbose"
                (default: "Verbose").

        Returns:
            (dict): Return value from the ``DescribeInferenceRecommendationsJob`` API.

        Raises:
            exceptions.CapacityError: If the Inference Recommender job fails with CapacityError.
            exceptions.UnexpectedStatusException: If the Inference Recommender job fails.
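
        Example:
            A minimal usage sketch; the role ARN, payload location, and model name are
            hypothetical placeholders::

                job_name = sagemaker_session.create_inference_recommendations_job(
                    role="arn:aws:iam::123456789012:role/MyRole",  # hypothetical ARN
                    sample_payload_url="s3://my-bucket/payload.tar.gz",
                    supported_content_types=["text/csv"],
                    model_name="my-model",  # hypothetical model name
                )
                desc = sagemaker_session.wait_for_inference_recommendations_job(
                    job_name, poll=120, log_level="Quiet"
                )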
""" if log_level == "Quiet": _wait_until( lambda: _describe_inference_recommendations_job_status( self.sagemaker_client, job_name ), poll, ) elif log_level == "Verbose": _display_inference_recommendations_job_steps_status( self, self.sagemaker_client, job_name ) else: raise ValueError("log_level must be either Quiet or Verbose") desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name) _check_job_status(job_name, desc, "Status") return desc def get_model_package_args( content_types, response_types, inference_instances=None, transform_instances=None, model_package_name=None, model_package_group_name=None, model_data=None, image_uri=None, model_metrics=None, metadata_properties=None, marketplace_cert=False, approval_status=None, description=None, tags=None, container_def_list=None, drift_check_baselines=None, customer_metadata_properties=None, validation_specification=None, domain=None, sample_payload_url=None, task=None, ): """Get arguments for create_model_package method. Args: content_types (list): The supported MIME types for the input data. response_types (list): The supported MIME types for the output data. inference_instances (list): A list of the instance types that are used to generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation job can be run or on which an endpoint can be deployed (default: None). model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package model_data (str): s3 URI to the model artifacts from training (default: None). image_uri (str): Inference image uri for the container. Model class' self.image will be used if it is None (default: None). model_metrics (ModelMetrics): ModelMetrics object (default: None). metadata_properties (MetadataProperties): MetadataProperties object (default: None). marketplace_cert (bool): A boolean value indicating if the Model Package is certified for AWS Marketplace (default: False). approval_status (str): Model Approval Status, values can be "Approved", "Rejected", or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). container_def_list (list): A list of container defintiions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). customer_metadata_properties (dict[str, str]): A dictionary of key-value paired metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). sample_payload_url (str): The S3 path where the sample payload is stored (default: None). task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). Returns: dict: A dictionary of method argument names and values. 
""" if container_def_list is not None: containers = container_def_list else: container = { "Image": image_uri, "ModelDataUrl": model_data, } containers = [container] model_package_args = { "containers": containers, "content_types": content_types, "response_types": response_types, "inference_instances": inference_instances, "transform_instances": transform_instances, "marketplace_cert": marketplace_cert, } if model_package_name is not None: model_package_args["model_package_name"] = model_package_name if model_package_group_name is not None: model_package_args["model_package_group_name"] = model_package_group_name if model_metrics is not None: model_package_args["model_metrics"] = model_metrics._to_request_dict() if drift_check_baselines is not None: model_package_args["drift_check_baselines"] = drift_check_baselines._to_request_dict() if metadata_properties is not None: model_package_args["metadata_properties"] = metadata_properties._to_request_dict() if approval_status is not None: model_package_args["approval_status"] = approval_status if description is not None: model_package_args["description"] = description if tags is not None: model_package_args["tags"] = tags if customer_metadata_properties is not None: model_package_args["customer_metadata_properties"] = customer_metadata_properties if validation_specification is not None: model_package_args["validation_specification"] = validation_specification if domain is not None: model_package_args["domain"] = domain if sample_payload_url is not None: model_package_args["sample_payload_url"] = sample_payload_url if task is not None: model_package_args["task"] = task return model_package_args def get_create_model_package_request( model_package_name=None, model_package_group_name=None, containers=None, content_types=None, response_types=None, inference_instances=None, transform_instances=None, model_metrics=None, metadata_properties=None, marketplace_cert=False, approval_status="PendingManualApproval", description=None, tags=None, drift_check_baselines=None, customer_metadata_properties=None, validation_specification=None, domain=None, sample_payload_url=None, task=None, ): """Get request dictionary for CreateModelPackage API. Args: model_package_name (str): Model Package name, exclusive to `model_package_group_name`, using `model_package_name` makes the Model Package un-versioned (default: None). model_package_group_name (str): Model Package Group name, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package containers (list): A list of inference containers that can be used for inference specifications of Model Package (default: None). content_types (list): The supported MIME types for the input data (default: None). response_types (list): The supported MIME types for the output data (default: None). inference_instances (list): A list of the instance types that are used to generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation job can be run or on which an endpoint can be deployed (default: None). model_metrics (ModelMetrics): ModelMetrics object (default: None). metadata_properties (MetadataProperties): MetadataProperties object (default: None). marketplace_cert (bool): A boolean value indicating if the Model Package is certified for AWS Marketplace (default: False). approval_status (str): Model Approval Status, values can be "Approved", "Rejected", or "PendingManualApproval" (default: "PendingManualApproval"). 
description (str): Model Package description (default: None). tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). customer_metadata_properties (dict[str, str]): A dictionary of key-value paired metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). sample_payload_url (str): The S3 path where the sample payload is stored (default: None). task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). """ if all([model_package_name, model_package_group_name]): raise ValueError( "model_package_name and model_package_group_name cannot be present at the " "same time." ) request_dict = {} if model_package_name is not None: request_dict["ModelPackageName"] = model_package_name if model_package_group_name is not None: request_dict["ModelPackageGroupName"] = model_package_group_name if description is not None: request_dict["ModelPackageDescription"] = description if tags is not None: request_dict["Tags"] = tags if model_metrics: request_dict["ModelMetrics"] = model_metrics if drift_check_baselines: request_dict["DriftCheckBaselines"] = drift_check_baselines if metadata_properties: request_dict["MetadataProperties"] = metadata_properties if customer_metadata_properties is not None: request_dict["CustomerMetadataProperties"] = customer_metadata_properties if validation_specification: request_dict["ValidationSpecification"] = validation_specification if domain is not None: request_dict["Domain"] = domain if sample_payload_url is not None: request_dict["SamplePayloadUrl"] = sample_payload_url if task is not None: request_dict["Task"] = task if containers is not None: if not all([content_types, response_types]): raise ValueError( "content_types and response_types " "must be provided if containers is present." ) inference_specification = { "Containers": containers, "SupportedContentTypes": content_types, "SupportedResponseMIMETypes": response_types, } if model_package_group_name is not None: if inference_instances is not None: inference_specification.update( { "SupportedRealtimeInferenceInstanceTypes": inference_instances, } ) if transform_instances is not None: inference_specification.update( { "SupportedTransformInstanceTypes": transform_instances, } ) else: if not all([inference_instances, transform_instances]): raise ValueError( "inference_instances and transform_instances " "must be provided if model_package_group_name is not present." ) inference_specification.update( { "SupportedRealtimeInferenceInstanceTypes": inference_instances, "SupportedTransformInstanceTypes": transform_instances, } ) request_dict["InferenceSpecification"] = inference_specification request_dict["CertifyForMarketplace"] = marketplace_cert request_dict["ModelApprovalStatus"] = approval_status return request_dict def update_args(args: Dict[str, Any], **kwargs): """Updates the request arguments dict with the value if populated. This is to handle the case that the service API doesn't like NoneTypes for argument values. 
    Args:
        args (Dict[str, Any]): the request arguments dict
        kwargs: key, value pairs to update the args dict with
    """
    for key, value in kwargs.items():
        if value is not None:
            args.update({key: value})


def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None):
    """Create a definition for executing a container as part of a SageMaker model.

    Args:
        image_uri (str): Docker image URI to run for this container.
        model_data_url (str or dict[str, Any]): S3 location of model data required by this
            container, e.g. SageMaker training job model artifacts. It can either be a string
            representing the S3 URI of the model data, or a dictionary representing a
            ``ModelDataSource`` object. (default: None).
        env (dict[str, str]): Environment variables to set inside the container (default: None).
        container_mode (str): The model container mode. Valid modes:
            * MultiModel: Indicates that the model container can support hosting multiple models
            * SingleModel: Indicates that the model container can support hosting a single model
            This is the default model container mode when container_mode = None
        image_config (dict[str, str]): Specifies whether the image of the model container is
            pulled from ECR, or a private registry in your VPC. By default it is set to pull the
            model container image from ECR. (default: None).

    Returns:
        dict[str, str]: A complete container definition object usable with the CreateModel API
        if passed via the `PrimaryContainer` field.
    """
    if env is None:
        env = {}
    c_def = {"Image": image_uri, "Environment": env}

    if isinstance(model_data_url, dict):
        c_def["ModelDataSource"] = model_data_url
    elif model_data_url:
        c_def["ModelDataUrl"] = model_data_url

    if container_mode:
        c_def["Mode"] = container_mode
    if image_config:
        c_def["ImageConfig"] = image_config
    return c_def


def pipeline_container_def(models, instance_type=None):
    """Create a definition for executing a pipeline of containers as part of a SageMaker model.

    Args:
        models (list[sagemaker.Model]): this will be a list of ``sagemaker.Model`` objects in the
            order the inference should be invoked.
        instance_type (str): The EC2 instance type to deploy this Model to. For example,
            'ml.p2.xlarge' (default: None).

    Returns:
        list[dict[str, str]]: list of container definition objects usable with the CreateModel
        API for inference pipelines if passed via the `Containers` field.
    """
    c_defs = []  # should contain the container definitions in the same order the customer passed
    for model in models:
        c_defs.append(model.prepare_container_def(instance_type))
    return c_defs


def production_variant(
    model_name,
    instance_type=None,
    initial_instance_count=None,
    variant_name="AllTraffic",
    initial_weight=1,
    accelerator_type=None,
    serverless_inference_config=None,
    volume_size=None,
    model_data_download_timeout=None,
    container_startup_health_check_timeout=None,
):
    """Create a production variant description suitable for use in a ``ProductionVariant`` list.

    This is also part of a ``CreateEndpointConfig`` request.

    Args:
        model_name (str): The name of the SageMaker model this production variant references.
        instance_type (str): The EC2 instance type for this production variant. For example,
            'ml.c4.8xlarge'.
        initial_instance_count (int): The initial instance count for this production variant
            (default: 1).
        variant_name (string): The ``VariantName`` of this production variant
            (default: 'AllTraffic').
        initial_weight (int): The relative ``InitialVariantWeight`` of this production variant
            (default: 1).
        accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
            For example, 'ml.eia1.medium'. For more information:
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
        serverless_inference_config (dict): Specifies a configuration dict related to a
            serverless endpoint. The dict is converted from a
            sagemaker.serverless.ServerlessInferenceConfig object (default: None).
        volume_size (int): The size, in GB, of the ML storage volume attached to the individual
            inference instance associated with the production variant. Currently only Amazon EBS
            gp2 storage volumes are supported.
        model_data_download_timeout (int): The timeout value, in seconds, to download and extract
            model data from Amazon S3 to the individual inference instance associated with this
            production variant.
        container_startup_health_check_timeout (int): The timeout value, in seconds, for your
            inference container to pass health check by SageMaker Hosting. For more information
            about health check see:
            https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests

    Returns:
        dict[str, str]: A SageMaker ``ProductionVariant`` description.
    """
    production_variant_configuration = {
        "ModelName": model_name,
        "VariantName": variant_name,
        "InitialVariantWeight": initial_weight,
    }

    if accelerator_type:
        production_variant_configuration["AcceleratorType"] = accelerator_type

    if serverless_inference_config:
        production_variant_configuration["ServerlessConfig"] = serverless_inference_config
    else:
        initial_instance_count = initial_instance_count or 1
        production_variant_configuration["InitialInstanceCount"] = initial_instance_count
        production_variant_configuration["InstanceType"] = instance_type
        update_args(
            production_variant_configuration,
            VolumeSizeInGB=volume_size,
            ModelDataDownloadTimeoutInSeconds=model_data_download_timeout,
            ContainerStartupHealthCheckTimeoutInSeconds=container_startup_health_check_timeout,
        )

    return production_variant_configuration


def get_execution_role(sagemaker_session=None):
    """Return the role ARN whose credentials are used to call the API.

    Raises an exception if the current AWS identity is not a role.

    Args:
        sagemaker_session (Session): Current SageMaker session.

    Returns:
        (str): The role ARN.
    """
    if not sagemaker_session:
        sagemaker_session = Session()
    arn = sagemaker_session.get_caller_identity_arn()

    if ":role/" in arn:
        return arn
    message = (
        "The current AWS identity is not a role: {}, therefore it cannot be used as a "
        "SageMaker execution role"
    )
    raise ValueError(message.format(arn))


def generate_default_sagemaker_bucket_name(boto_session):
    """Generates a name for the default SageMaker S3 bucket.

    Args:
        boto_session (boto3.session.Session): The underlying Boto3 session to which AWS service
            calls are delegated.
    """
    region = boto_session.region_name
    account = boto_session.client(
        "sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
    ).get_caller_identity()["Account"]
    return "sagemaker-{}-{}".format(region, account)


def _deployment_entity_exists(describe_fn):
    """Placeholder docstring"""
    try:
        describe_fn()
        return True
    except ClientError as ce:
        error_code = ce.response["Error"]["Code"]
        if not (
            error_code == "ValidationException"
            and "Could not find" in ce.response["Error"]["Message"]
        ):
            raise ce
        return False


def _create_resource(create_fn):
    """Call the create function, and accept/pass when the resource already exists.

    This is a helper function to use an existing resource if one is found when creating.

    Args:
        create_fn: Create resource function.

    Returns:
        (bool): True if a new resource was created, False if the resource already exists.
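
    Example:
        A minimal usage sketch; ``sagemaker_client`` and ``production_variants`` are
        assumed to exist in the caller's scope, and the config name is hypothetical::

            created = _create_resource(
                lambda: sagemaker_client.create_endpoint_config(
                    EndpointConfigName="my-config",  # hypothetical name
                    ProductionVariants=production_variants,
                )
            )
            if not created:
                print("Endpoint config already existed; reusing it.")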
""" try: create_fn() # create function succeeded, resource does not exist already return True except ClientError as ce: error_code = ce.response["Error"]["Code"] error_message = ce.response["Error"]["Message"] already_exists_exceptions = ["ValidationException", "ResourceInUse"] already_exists_msg_patterns = ["Cannot create already existing", "already exists"] if not ( error_code in already_exists_exceptions and any(p in error_message for p in already_exists_msg_patterns) ): raise ce # no new resource created as resource already exists return False def _train_done(sagemaker_client, job_name, last_desc): """Placeholder docstring""" in_progress_statuses = ["InProgress", "Created"] desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) status = desc["TrainingJobStatus"] if secondary_training_status_changed(desc, last_desc): print() print(secondary_training_status_message(desc, last_desc), end="") else: print(".", end="") sys.stdout.flush() if status in in_progress_statuses: return desc, False print() return desc, True def _processing_job_status(sagemaker_client, job_name): """Prints the job status for the given processing job name. Returns the job description. Args: sagemaker_client: The boto3 SageMaker client. job_name (str): The name of the job for which the status is requested. Returns: dict: The processing job description. """ compile_status_codes = { "Completed": "!", "InProgress": ".", "Failed": "*", "Stopped": "s", "Stopping": "_", } in_progress_statuses = ["InProgress", "Stopping", "Starting"] desc = sagemaker_client.describe_processing_job(ProcessingJobName=job_name) status = desc["ProcessingJobStatus"] status = _STATUS_CODE_TABLE.get(status, status) print(compile_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None return desc def _edge_packaging_job_status(sagemaker_client, job_name): """Process the current status of a packaging job. Args: sagemaker_client (boto3.client.sagemaker): a sagemaker client job_name (str): the name of the job to inspect. 
Returns: Dict: the status of the edge packaging job """ package_status_codes = { "Completed": "!", "InProgress": ".", "Failed": "*", "Stopped": "s", "Stopping": "_", } in_progress_statuses = ["InProgress", "Stopping", "Starting"] desc = sagemaker_client.describe_edge_packaging_job(EdgePackagingJobName=job_name) status = desc["EdgePackagingJobStatus"] status = _STATUS_CODE_TABLE.get(status, status) print(package_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None return desc def _compilation_job_status(sagemaker_client, job_name): """Placeholder docstring""" compile_status_codes = { "Completed": "!", "InProgress": ".", "Failed": "*", "Stopped": "s", "Stopping": "_", } in_progress_statuses = ["InProgress", "Stopping", "Starting"] desc = sagemaker_client.describe_compilation_job(CompilationJobName=job_name) status = desc["CompilationJobStatus"] status = _STATUS_CODE_TABLE.get(status, status) print(compile_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None return desc def _tuning_job_status(sagemaker_client, job_name): """Placeholder docstring""" tuning_status_codes = { "Completed": "!", "InProgress": ".", "Failed": "*", "Stopped": "s", "Stopping": "_", } in_progress_statuses = ["InProgress", "Stopping"] desc = sagemaker_client.describe_hyper_parameter_tuning_job( HyperParameterTuningJobName=job_name ) status = desc["HyperParameterTuningJobStatus"] print(tuning_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None print("") return desc def _transform_job_status(sagemaker_client, job_name): """Placeholder docstring""" transform_job_status_codes = { "Completed": "!", "InProgress": ".", "Failed": "*", "Stopped": "s", "Stopping": "_", } in_progress_statuses = ["InProgress", "Stopping"] desc = sagemaker_client.describe_transform_job(TransformJobName=job_name) status = desc["TransformJobStatus"] print(transform_job_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None print("") return desc def _auto_ml_job_status(sagemaker_client, job_name): """Placeholder docstring""" auto_ml_job_status_codes = { "Completed": "!", "InProgress": ".", "Failed": "*", "Stopped": "s", "Stopping": "_", } in_progress_statuses = ["InProgress", "Stopping"] desc = sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name) status = desc["AutoMLJobStatus"] print(auto_ml_job_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None print("") return desc def _create_model_package_status(sagemaker_client, model_package_name): """Placeholder docstring""" in_progress_statuses = ["InProgress", "Pending"] desc = sagemaker_client.describe_model_package(ModelPackageName=model_package_name) status = desc["ModelPackageStatus"] print(".", end="") sys.stdout.flush() if status in in_progress_statuses: return None print("") return desc def _describe_inference_recommendations_job_status(sagemaker_client, job_name: str): """Describes the status of a job and returns the job description. Args: sagemaker_client (boto3.client.sagemaker): A SageMaker client. job_name (str): The name of the job. Returns: dict: The job description, or None if the job is still in progress. 
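
    Example:
        A minimal sketch of how this poller pairs with ``_wait_until``, mirroring its
        use in ``wait_for_inference_recommendations_job``::

            desc = _wait_until(
                lambda: _describe_inference_recommendations_job_status(
                    sagemaker_client, job_name
                ),
                poll=120,
            )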
""" inference_recommendations_job_status_codes = { "PENDING": ".", "IN_PROGRESS": ".", "COMPLETED": "!", "FAILED": "*", "STOPPING": "_", "STOPPED": "s", } in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"} desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name) status = desc["Status"] print(inference_recommendations_job_status_codes.get(status, "?"), end="", flush=True) if status in in_progress_statuses: return None print("") return desc def _display_inference_recommendations_job_steps_status( sagemaker_session, sagemaker_client, job_name: str, poll: int = 60 ): """Placeholder docstring""" cloudwatch_client = sagemaker_session.boto_session.client("logs") in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"} log_group_name = "/aws/sagemaker/InferenceRecommendationsJobs" log_stream_name = job_name + "/execution" initial_logs_batch = get_log_events_for_inference_recommender( cloudwatch_client, log_group_name, log_stream_name ) print(f"Retrieved logStream: {log_stream_name} from logGroup: {log_group_name}", flush=True) events = initial_logs_batch["events"] print(*[event["message"] for event in events], sep="\n", flush=True) next_forward_token = initial_logs_batch["nextForwardToken"] if events else None flush_remaining = True while True: logs_batch = ( cloudwatch_client.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, nextToken=next_forward_token, ) if next_forward_token else cloudwatch_client.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name ) ) events = logs_batch["events"] desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name) status = desc["Status"] if not events: if status in in_progress_statuses: time.sleep(poll) continue if flush_remaining: flush_remaining = False time.sleep(poll) continue next_forward_token = logs_batch["nextForwardToken"] print(*[event["message"] for event in events], sep="\n", flush=True) if status not in in_progress_statuses: break time.sleep(poll) def get_log_events_for_inference_recommender(cw_client, log_group_name, log_stream_name): """Retrieves log events from the specified CloudWatch log group and log stream. Args: cw_client (boto3.client): A boto3 CloudWatch client. log_group_name (str): The name of the CloudWatch log group. log_stream_name (str): The name of the CloudWatch log stream. Returns: (dict): A dictionary containing log events from CloudWatch log group and log stream. """ print("Fetching logs from CloudWatch...", flush=True) for _ in retries( max_retry_count=30, # 30*10 = 5min exception_message_prefix="Waiting for cloudwatch stream to appear. 
", seconds_to_sleep=10, ): try: return cw_client.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name ) except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": pass def _deploy_done(sagemaker_client, endpoint_name): """Placeholder docstring""" hosting_status_codes = { "OutOfService": "x", "Creating": "-", "Updating": "-", "InService": "!", "RollingBack": "<", "Deleting": "o", "Failed": "*", } in_progress_statuses = ["Creating", "Updating"] desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name) status = desc["EndpointStatus"] print(hosting_status_codes.get(status, "?"), end="") sys.stdout.flush() return None if status in in_progress_statuses else desc def _wait_until_training_done(callable_fn, desc, poll=5): """Placeholder docstring""" elapsed_time = 0 finished = None job_desc = desc while not finished: try: elapsed_time += poll time.sleep(poll) job_desc, finished = callable_fn(job_desc) except botocore.exceptions.ClientError as err: # For initial 5 mins we accept/pass AccessDeniedException. # The reason is to await tag propagation to avoid false AccessDenied claims for an # access policy based on resource tags, The caveat here is for true AccessDenied # cases the routine will fail after 5 mins if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300: LOGGER.warning( "Received AccessDeniedException. This could mean the IAM role does not " "have the resource permissions, in which case please add resource access " "and retry. For cases where the role has tag based resource policy, " "continuing to wait for tag propagation.." ) continue raise err return job_desc def _wait_until(callable_fn, poll=5): """Placeholder docstring""" elapsed_time = 0 result = None while result is None: try: elapsed_time += poll time.sleep(poll) result = callable_fn() except botocore.exceptions.ClientError as err: # For initial 5 mins we accept/pass AccessDeniedException. # The reason is to await tag propagation to avoid false AccessDenied claims for an # access policy based on resource tags, The caveat here is for true AccessDenied # cases the routine will fail after 5 mins if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300: LOGGER.warning( "Received AccessDeniedException. This could mean the IAM role does not " "have the resource permissions, in which case please add resource access " "and retry. For cases where the role has tag based resource policy, " "continuing to wait for tag propagation.." 
def _wait_until(callable_fn, poll=5):
    """Poll ``callable_fn`` every ``poll`` seconds until it returns a non-None result.

    Args:
        callable_fn: A zero-argument callable that returns None while the
            operation is in progress and a result once it completes.
        poll (int): Seconds to sleep between polls (default: 5).

    Returns:
        The first non-None value returned by ``callable_fn``.
    """
    elapsed_time = 0
    result = None
    while result is None:
        try:
            elapsed_time += poll
            time.sleep(poll)
            result = callable_fn()
        except botocore.exceptions.ClientError as err:
            # Same 5-minute AccessDenied grace period as in
            # _wait_until_training_done, to allow tag propagation for
            # tag-based access policies.
            if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
                LOGGER.warning(
                    "Received AccessDeniedException. This could mean the IAM role does not "
                    "have the resource permissions, in which case please add resource access "
                    "and retry. For cases where the role has tag based resource policy, "
                    "continuing to wait for tag propagation..."
                )
                continue
            raise err

    return result


def _expand_container_def(c_def):
    """Expand a string image URI into a container definition dict, if needed."""
    if isinstance(c_def, six.string_types):
        return container_def(c_def)
    return c_def


def _vpc_config_from_training_job(
    training_job_desc, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT
):
    """Return the VPC config from a training job description, or a sanitized override."""
    if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
        return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
    return vpc_utils.sanitize(vpc_config_override)


def _get_initial_job_state(description, status_key, wait):
    """Return the initial LogState for tailing a job's logs."""
    status = description[status_key]
    job_already_completed = status in ("Completed", "Failed", "Stopped")
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE


def _rule_statuses_changed(current_statuses, last_statuses):
    """Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules.

    Returns True on the first observation, or if any rule's evaluation status
    differs from the last observation.
    """
    if not last_statuses:
        return True

    for current, last in zip(current_statuses, last_statuses):
        if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
            current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
        ):
            return True

    return False
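# ``_logs_for_job`` below is the module-level worker behind the session's
# log-tailing APIs (e.g. what an estimator's ``fit(wait=True, logs=True)``
# ultimately drives). A minimal direct invocation (the job name is
# illustrative):
#
#   import boto3
#
#   desc = _logs_for_job(boto3.Session(), "my-training-job", wait=True, poll=10)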
def _logs_for_job(  # noqa: C901 - suppress complexity warning for this method
    boto_session, job_name, wait=False, poll=10, log_type="All", timeout=None
):
    """Display logs for a given training job, optionally tailing them until job is complete.

    If the output is a tty or a Jupyter cell, it will be color-coded based on
    which instance the log entry is from.

    Args:
        boto_session (boto3.session.Session): The underlying Boto3 session to
            which AWS service calls are delegated.
        job_name (str): Name of the training job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job
            completes (default: False).
        poll (int): The interval in seconds between polling for new log entries
            and job completion (default: 10).
        log_type (str): Which logs to print. Acceptable values are "All",
            "None", "Training", or "Rules". To maintain backwards
            compatibility, boolean values are also accepted and converted to
            strings.
        timeout (int): Timeout in seconds to wait until the job is completed.
            ``None`` by default.

    Returns:
        dict: The last DescribeTrainingJob response.

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If waiting and the training job fails.
    """
    sagemaker_client = boto_session.client("sagemaker")
    request_end_time = time.time() + timeout if timeout else None
    description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    print(secondary_training_status_message(description, None), end="")

    instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
        boto_session, description, job="Training"
    )

    state = _get_initial_job_state(description, "TrainingJobStatus", wait)

    # The loop below implements a state machine that alternates between checking the job status
    # and reading whatever is available in the logs at this point. Note that if we were
    # called with wait == False, we never check the job status.
    #
    # If wait == TRUE and the job is not complete, the initial state is TAILING.
    # If wait == FALSE, the initial state is COMPLETE (regardless of whether the job really is
    # complete).
    #
    # The state table:
    #
    # STATE             ACTIONS                        CONDITION        NEW STATE
    # ----------------  -----------------------------  ---------------  ----------------
    # TAILING           Read logs, Pause, Get status   Job complete     JOB_COMPLETE
    #                                                  Else             TAILING
    # JOB_COMPLETE      Read logs, Pause               Any              COMPLETE
    # COMPLETE          Read logs, Exit                                 N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    last_describe_job_call = time.time()
    last_description = description
    last_debug_rule_statuses = None
    last_profiler_rule_statuses = None

    while True:
        _flush_log_streams(
            stream_names,
            instance_count,
            client,
            log_group,
            job_name,
            positions,
            dot,
            color_wrap,
        )
        if timeout and time.time() > request_end_time:
            print("Timeout Exceeded. {} seconds elapsed.".format(timeout))
            break

        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
            last_describe_job_call = time.time()

            if secondary_training_status_changed(description, last_description):
                print()
                print(secondary_training_status_message(description, last_description), end="")
                last_description = description

            status = description["TrainingJobStatus"]

            if status in ("Completed", "Failed", "Stopped"):
                print()
                state = LogState.JOB_COMPLETE

            # Print prettified logs related to the status of SageMaker Debugger rules.
            debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
            if (
                debug_rule_statuses
                and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in debug_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_debug_rule_statuses = debug_rule_statuses

            # Print prettified logs related to the status of SageMaker Profiler rules.
            profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
            if (
                profiler_rule_statuses
                and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in profiler_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_profiler_rule_statuses = profiler_rule_statuses

    if wait:
        _check_job_status(job_name, description, "TrainingJobStatus")
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than
        # total time
        training_time = description.get("TrainingTimeInSeconds")
        billable_time = description.get("BillableTimeInSeconds")
        if training_time is not None:
            print("Training seconds:", training_time * instance_count)
        if billable_time is not None:
            print("Billable seconds:", billable_time * instance_count)
            if description.get("EnableManagedSpotTraining"):
                saving = (1 - float(billable_time) / training_time) * 100
                print("Managed Spot Training savings: {:.1f}%".format(saving))

    return last_description
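# Worked example for the managed-spot savings line above: with
# TrainingTimeInSeconds=1000 and BillableTimeInSeconds=300,
# saving = (1 - 300 / 1000) * 100 = 70.0, printed as
# "Managed Spot Training savings: 70.0%".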
""" status = desc[status_key_name] # If the status is capital case, then convert it to Camel case status = _STATUS_CODE_TABLE.get(status, status) if status == "Stopped": LOGGER.warning( "Job ended with status 'Stopped' rather than 'Completed'. " "This could mean the job timed out or stopped early for some other reason: " "Consider checking whether it completed as you expect." ) elif status != "Completed": reason = desc.get("FailureReason", "(No reason provided)") job_type = status_key_name.replace("JobStatus", " job") message = "Error for {job_type} {job_name}: {status}. Reason: {reason}".format( job_type=job_type, job_name=job, status=status, reason=reason ) if "CapacityError" in str(reason): raise exceptions.CapacityError( message=message, allowed_statuses=["Completed", "Stopped"], actual_status=status, ) raise exceptions.UnexpectedStatusException( message=message, allowed_statuses=["Completed", "Stopped"], actual_status=status, ) def _logs_init(boto_session, description, job): """Placeholder docstring""" if job == "Training": if "InstanceGroups" in description["ResourceConfig"]: instance_count = 0 for instanceGroup in description["ResourceConfig"]["InstanceGroups"]: instance_count += instanceGroup["InstanceCount"] else: instance_count = description["ResourceConfig"]["InstanceCount"] elif job == "Transform": instance_count = description["TransformResources"]["InstanceCount"] elif job == "Processing": instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] elif job == "AutoML": instance_count = 0 stream_names = [] # The list of log streams positions = {} # The current position in each stream, map of stream name -> position # Increase retries allowed (from default of 4), as we don't want waiting for a training job # to be interrupted by a transient exception. config = botocore.config.Config(retries={"max_attempts": 15}) client = boto_session.client("logs", config=config) log_group = "/aws/sagemaker/" + job + "Jobs" dot = False color_wrap = sagemaker.logs.ColorWrap() return instance_count, stream_names, positions, client, log_group, dot, color_wrap def _flush_log_streams( stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap ): """Placeholder docstring""" if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list # may be dynamic until we have a stream for every instance. 
        try:
            streams = client.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=job_name + "/",
                orderBy="LogStreamName",
                limit=min(instance_count, 50),
            )
            stream_names = [s["logStreamName"] for s in streams["logStreams"]]

            while "nextToken" in streams:
                streams = client.describe_log_streams(
                    logGroupName=log_group,
                    logStreamNamePrefix=job_name + "/",
                    orderBy="LogStreamName",
                    limit=50,
                    nextToken=streams["nextToken"],
                )
                stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])

            positions.update(
                [
                    (s, sagemaker.logs.Position(timestamp=0, skip=0))
                    for s in stream_names
                    if s not in positions
                ]
            )
        except ClientError as e:
            # On the very first training job run on an account, there's no log group until
            # the container starts logging, so ignore any errors thrown about that
            err = e.response.get("Error", {})
            if err.get("Code", None) != "ResourceNotFoundException":
                raise

    if len(stream_names) > 0:
        if dot:
            print("")
            dot = False
        for idx, event in sagemaker.logs.multi_stream_iter(
            client, log_group, stream_names, positions
        ):
            color_wrap(idx, event["message"])
            ts, count = positions[stream_names[idx]]
            if event["timestamp"] == ts:
                positions[stream_names[idx]] = sagemaker.logs.Position(
                    timestamp=ts, skip=count + 1
                )
            else:
                positions[stream_names[idx]] = sagemaker.logs.Position(
                    timestamp=event["timestamp"], skip=1
                )
    else:
        dot = True
        print(".", end="")
        sys.stdout.flush()


s3_input = deprecated_class(TrainingInput, "sagemaker.session.s3_input")
ShuffleConfig = deprecated_class(ShuffleConfig, "sagemaker.session.ShuffleConfig")
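# The deprecated aliases above keep the old ``sagemaker.session`` import paths
# importable while steering callers to the new names. Hypothetical user code:
#
#   from sagemaker.session import s3_input  # emits a deprecation warning on use
#
#   channel = s3_input("s3://my-bucket/train")  # behaves like TrainingInput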