"""CloudFormation custom resource that runs a SageMaker training job.

Create/Update events start a training job and are then polled until the job
completes; Delete events stop the job if it is still running (training jobs
cannot be deleted, only stopped). Event dispatch and polling are handled by
``crhelper``.
"""
import json
import logging

import boto3
import botocore
from botocore.exceptions import ClientError
from crhelper import CfnResource

logger = logging.getLogger(__name__)
sm = boto3.client("sagemaker")

# crhelper makes it easier to implement a CloudFormation custom resource
helper = CfnResource()


# CFN Handlers


def lambda_handler(event, context):
    """Lambda entry point: hand the CloudFormation event to crhelper."""
    helper(event, context)


@helper.create
@helper.update
def create_handler(event, context):
    """
    Called when the CloudFormation custom resource sends a create or update
    event. Starts the training job; completion is checked by ``poll_create``.
    """
    create_training_job(event)


@helper.delete
def delete_handler(event, context):
    """
    Training Jobs can not be deleted only stopped if running.
    Completion of the stop is checked by ``poll_delete``.
    """
    training_job_name = get_training_job_name(event)
    stop_training_job(training_job_name)


@helper.poll_create
@helper.poll_update
def poll_create(event, context):
    """
    Return true if the resource has been created and false otherwise so
    CloudFormation polls again.

    Raises if the training job reached an unexpected (e.g. failed) status.
    """
    training_job_name = get_training_job_name(event)
    logger.info("Polling for training job: %s", training_job_name)
    return is_training_job_ready(training_job_name)


@helper.poll_delete
def poll_delete(event, context):
    """
    Return true once the training job is no longer running (or does not
    exist), false while it is still running or stopping so polling continues.
    """
    training_job_name = get_training_job_name(event)
    logger.info("Polling for stopped training job: %s", training_job_name)
    return stop_training_job(training_job_name)


# Helper Functions


def get_training_job_name(event):
    """Return the ``TrainingJobName`` resource property of the CFN event."""
    return event["ResourceProperties"]["TrainingJobName"]


def is_training_job_ready(training_job_name):
    """Check whether the training job has completed.

    Args:
        training_job_name: Name of the SageMaker training job.

    Returns:
        True when the job status is ``Completed`` (and records the job name
        and ARN in ``helper.Data``); False while it is ``InProgress`` or
        ``Stopping``.

    Raises:
        Exception: For any other status (e.g. ``Failed``, ``Stopped``). The
            message includes SageMaker's ``FailureReason`` when one is given.
    """
    response = sm.describe_training_job(TrainingJobName=training_job_name)
    status = response["TrainingJobStatus"]

    if status == "Completed":
        logger.info("Training Job (%s) is Completed", training_job_name)
        # Return additional info
        helper.Data["TrainingJobName"] = training_job_name
        helper.Data["Arn"] = response["TrainingJobArn"]
        return True

    if status in ("InProgress", "Stopping"):
        logger.info(
            "Training job (%s) still in progress (%s), waiting and polling again...",
            training_job_name,
            response["SecondaryStatus"],
        )
        return False

    message = "Training job ({}) has unexpected status: {}".format(training_job_name, status)
    # Surface SageMaker's own explanation so it reaches the CloudFormation
    # failure reason instead of just the bare status.
    failure_reason = response.get("FailureReason")
    if failure_reason:
        message += ", reason: {}".format(failure_reason)
    raise Exception(message)


def create_training_job(event):
    """Start the training job described by the event's resource properties.

    Records the job name and ARN in ``helper.Data`` and returns the ARN.
    """
    training_job_name = get_training_job_name(event)
    request = get_training_request(event)

    logger.info("Creating training job with name: %s", training_job_name)
    logger.debug(json.dumps(request))
    response = sm.create_training_job(**request)

    # Update Output Parameters
    helper.Data["TrainingJobName"] = training_job_name
    helper.Data["Arn"] = response["TrainingJobArn"]
    return helper.Data["Arn"]


# TODO: Test to see what Validation/Resource not found errors are returned for training jobs
def stop_training_job(training_job_name):
    """Stop the training job if it is running and report whether we are done.

    Args:
        training_job_name: Name of the SageMaker training job.

    Returns:
        False while the job is still running: a stop was just requested for an
        ``InProgress`` job, or the job is already ``Stopping``. True when there
        is nothing left to wait for, meaning any other status or a job that
        does not exist.

    Raises:
        ClientError: For any SageMaker error other than "resource not found".
    """
    try:
        training_job = sm.describe_training_job(TrainingJobName=training_job_name)
        status = training_job["TrainingJobStatus"]
        if status == "InProgress":
            logger.info("Stopping InProgress training job: %s", training_job_name)
            sm.stop_training_job(TrainingJobName=training_job_name)
            return False
        if status == "Stopping":
            # The stop has been requested but is not finished yet; keep polling
            # rather than reporting the job as stopped.
            logger.info("Training job (%s) is stopping, waiting and polling again...", training_job_name)
            return False
        logger.info("Training job status: %s, nothing to stop", status)
        return True
    except ClientError as e:
        # NOTE: A missing training job doesn't return a "ResourceNotFound"
        # code, so match the ValidationException message instead.
        error = e.response.get("Error", {})
        if (
            error.get("Code") == "ValidationException"
            and "resource not found" in error.get("Message", "")
        ):
            logger.info("Resource not found, nothing to stop")
            return True
        logger.exception("Unexpected error while trying to stop training job")
        # Bare raise keeps the original traceback.
        raise


def get_training_request(event):
    """Build the ``CreateTrainingJob`` request from the event's resource properties.

    ``TrainingJobRequest`` is a JSON string with the base request. The job
    name, the optional volume KMS key, and the experiment tracking config are
    then overlaid from the other properties.
    """
    props = event["ResourceProperties"]

    # Load raw request
    request = json.loads(props["TrainingJobRequest"])

    # Encrypt the training instance storage volume with KmsKeyId if provided.
    # An empty string is treated as "not provided": CloudFormation passes ""
    # for unset optional parameters.
    # NOTE(review): OutputDataConfig.KmsKeyId is not set here; confirm whether
    # model artifacts should be encrypted with the same key.
    kms_key_id = props.get("KmsKeyId")
    if kms_key_id:
        request["ResourceConfig"]["VolumeKmsKeyId"] = kms_key_id

    # Set the training job name
    request["TrainingJobName"] = props["TrainingJobName"]

    # Add experiment tracking
    request["ExperimentConfig"] = {
        "ExperimentName": props["ExperimentName"],
        "TrialName": props["TrialName"],
        "TrialComponentDisplayName": "Training",
    }

    return request