import datetime import os import subprocess import sys import time subprocess.check_call(["pip", "install", "--upgrade", "botocore", "--target", "/tmp/"]) subprocess.check_call(["pip", "install", "--upgrade", "boto3", "--target", "/tmp/"]) sys.path.insert(0, "/tmp/") import boto3 # pylint: disable=wrong-import-position import botocore import common from botocore.exceptions import ClientError logger = common.get_logger() DEFAULT_SLEEP_TIME_SECONDS = 10 LIST_MAX_RESULT_COUNT = 100 # lambda sets these as environment variables MAXIMUM_ENDPOINT_AGE = os.environ["MAX_ENDPOINT_AGE_IN_MINUTES"] REGION = os.environ["AWS_REGION"] def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS): """Retries until max retry count is reached. Args: max_retry_count (int): The retry count. exception_message_prefix (str): The message to include in the exception on failure. seconds_to_sleep (int): The number of seconds to sleep between executions. """ for i in range(max_retry_count): yield i time.sleep(seconds_to_sleep) raise Exception( "'{}' has reached the maximum retry count of {}".format( exception_message_prefix, max_retry_count ) ) def get_resources(client, next_token, before_timestamp, resource_type): list_req = {"MaxResults": LIST_MAX_RESULT_COUNT, "CreationTimeBefore": before_timestamp} if next_token: list_req.update({"NextToken": next_token}) if resource_type == "MonitoringSchedules": response = client.list_monitoring_schedules(**list_req) resource_type = "MonitoringScheduleSummaries" elif resource_type == "ProcessingJobs": response = client.list_processing_jobs(**list_req) resource_type = "ProcessingJobSummaries" elif resource_type == "Endpoints": list_req.update({"StatusEquals": "InService"}) response = client.list_endpoints(**list_req) elif resource_type == "EndpointConfigs": response = client.list_endpoint_configs(**list_req) elif resource_type == "Experiments": response = client.list_experiments( MaxResults=LIST_MAX_RESULT_COUNT, CreatedBefore=before_timestamp ) resource_type = "ExperimentSummaries" resources = response[resource_type] new_next = response.get("NextToken", None) return resources, new_next def stop_resources(client, resource_names, resource_type): stopped_all_resources = True for resource_name in resource_names: logger.info("Stopping %s", resource_name) if resource_type == "MonitoringSchedules": try: client.stop_monitoring_schedule(MonitoringScheduleName=resource_name) except Exception: # pylint: disable=broad-except logger.exception("Unable to stop monitoring schedule") stopped_all_resources = False for _ in retries(60, "Waiting for Monitoring Schedules to stop", seconds_to_sleep=5): status = client.describe_monitoring_schedule(MonitoringScheduleName=resource_name)[ "MonitoringScheduleStatus" ] if status in {"Stopped", "Failed", "Completed"}: break if resource_type == "ProcessingJobs": try: client.stop_processing_job(ProcessingJobName=resource_name) except Exception: # pylint: disable=broad-except logger.exception("Unable to stop processing job") stopped_all_resources = False for _ in retries(60, "Waiting for Processing Jobs to stop", seconds_to_sleep=5): status = client.describe_processing_job(ProcessingJobName=resource_name)[ "ProcessingJobStatus" ] if status in {"Stopped", "Failed", "Completed"}: break logger.info("Stopped %s", resource_name) time.sleep(0.5) return stopped_all_resources def batch_stop_resources(client, before_timestamp, resource_type): try: next_token = None more = True while more: logger.info("Searching for %s from before: %s", resource_type.lower(), before_timestamp) resources, next_token = get_resources( client, next_token, before_timestamp, resource_type ) logger.info("Found %s items, stopping now", len(resources)) if resource_type == "MonitoringSchedules": resource_names = [resource["MonitoringScheduleName"] for resource in resources] elif resource_type == "ProcessingJobs": resource_names = [resource["ProcessingJobName"] for resource in resources] stopped = stop_resources(client, resource_names, resource_type) more = stopped and (next_token is not None) time.sleep(1) finally: logger.info( "Finished cleaning %s at %s", resource_type.lower(), str(datetime.datetime.now()) ) def delete_experiment(client, experiment_name): trials = client.list_trials(ExperimentName=experiment_name) for trial in trials["TrialSummaries"]: trial_name = trial["TrialName"] trial_components = client.list_trial_components(TrialName=trial_name) for tc in trial_components["TrialComponentSummaries"]: tc_name = tc["TrialComponentName"] client.disassociate_trial_component( TrialName=trial_name, TrialComponentName=tc_name, ) try: client.delete_trial_component(TrialComponentName=tc_name) except botocore.exceptions.ClientError as e: # If the trial component is still linked to another trial, # ignore it for now. (It will get deleted once completely unlinked.) linked_tc_msg = ( "An error occurred (ValidationException) when calling the " "DeleteTrialComponent operation: TrialComponent %s is linked " "to 1 or more trials and cannot be deleted." ) % tc_name if linked_tc_msg in str(e): pass else: raise client.delete_trial(TrialName=trial_name) client.delete_experiment(ExperimentName=experiment_name) def delete_resources(client, resource_names, resource_type): for resource_name in resource_names: logger.info("Deleting %s", resource_name) if resource_type == "MonitoringSchedules": client.delete_monitoring_schedule(MonitoringScheduleName=resource_name) elif resource_type == "Endpoints": client.delete_endpoint(EndpointName=resource_name) elif resource_type == "EndpointConfigs": client.delete_endpoint_config(EndpointConfigName=resource_name) elif resource_type == "Experiments": delete_experiment(client, resource_name) logger.info("Deleted %s", resource_name) time.sleep(0.5) def batch_delete_resources(client, before_timestamp, resource_type): try: next_token = None more = True while more: logger.info("Searching for %s from before: %s", resource_type.lower(), before_timestamp) resources, next_token = get_resources( client, next_token, before_timestamp, resource_type ) logger.info("Found %s items, deleting now", len(resources)) if resource_type == "MonitoringSchedules": resource_names = [resource["MonitoringScheduleName"] for resource in resources] elif resource_type == "Endpoints": resource_names = [resource["EndpointName"] for resource in resources] elif resource_type == "EndpointConfigs": resource_names = [resource["EndpointConfigName"] for resource in resources] elif resource_type == "Experiments": resource_names = [resource["ExperimentName"] for resource in resources] delete_resources(client, resource_names, resource_type) more = next_token is not None time.sleep(1) finally: logger.info( "Finished cleaning %s at %s", resource_type.lower(), str(datetime.datetime.now()) ) def lambda_handler(event, context): # pylint: disable=unused-argument logger.info("Invoking endpoint cleanup at %s...", event["time"]) before_timestamp = datetime.datetime.now() - datetime.timedelta( minutes=int(MAXIMUM_ENDPOINT_AGE) ) # Get all regions available for SageMaker. sagemaker_regions = boto3.Session().get_available_regions("sagemaker") # Clean up resources in each region. for region in sagemaker_regions: try: sm_client = boto3.Session(region_name=region).client("sagemaker") # cleaning schedules batch_stop_resources(sm_client, before_timestamp, "MonitoringSchedules") batch_stop_resources(sm_client, before_timestamp, "ProcessingJobs") batch_delete_resources(sm_client, before_timestamp, "MonitoringSchedules") # cleaning endpoints batch_delete_resources(sm_client, before_timestamp, "Endpoints") # cleaning endpoint_configs batch_delete_resources(sm_client, before_timestamp, "EndpointConfigs") # cleaning trials and trial components batch_delete_resources(sm_client, before_timestamp, "Experiments") except ClientError as e: logger.debug("ERROR in region %s: %s", region, str(e))