# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import import json import os import time import boto3 def wait_for_s3_object( s3_bucket, key, local_dir, local_prefix="", aws_account=None, aws_region=None, timeout=1200, limit=20, fetch_only=None, training_job_name=None, ): """ Keep polling s3 object until it is generated. Pulling down latest data to local directory with short key Arguments: s3_bucket (string): s3 bucket name key (string): key for s3 object local_dir (string): local directory path to save s3 object local_prefix (string): local prefix path append to the local directory aws_account (string): aws account of the s3 bucket aws_region (string): aws region where the repo is located timeout (int): how long to wait for the object to appear before giving up limit (int): maximum number of files to download fetch_only (lambda): a function to decide if this object should be fetched or not training_job_name (string): training job name to query job status Returns: A list of all downloaded files, as local filenames """ session = boto3.Session() aws_account = aws_account or session.client("sts").get_caller_identity()["Account"] aws_region = aws_region or session.region_name s3 = session.resource("s3") sagemaker = session.client("sagemaker") bucket = s3.Bucket(s3_bucket) objects = [] print("Waiting for s3://%s/%s..." % (s3_bucket, key), end="", flush=True) start_time = time.time() cnt = 0 while len(objects) == 0: objects = list(bucket.objects.filter(Prefix=key)) if fetch_only: objects = list(filter(fetch_only, objects)) if objects: continue print(".", end="", flush=True) time.sleep(5) cnt += 1 if cnt % 80 == 0: print("") if time.time() > start_time + timeout: raise FileNotFoundError( "S3 object s3://%s/%s never appeared after %d seconds" % (s3_bucket, key, timeout) ) if training_job_name: training_job_status = sagemaker.describe_training_job( TrainingJobName=training_job_name )["TrainingJobStatus"] if training_job_status == "Failed": raise RuntimeError( "Training job {} failed while waiting for S3 object s3://{}/{}".format( training_job_name, s3_bucket, key ) ) print("\n", end="", flush=True) if len(objects) > limit: print("Only downloading %d of %d files" % (limit, len(objects))) objects = objects[-limit:] fetched_files = [] for obj in objects: print("Downloading %s" % obj.key) local_path = os.path.join(local_dir, local_prefix, obj.key.split("/")[-1]) obj.Object().download_file(local_path) fetched_files.append(local_path) return fetched_files def get_execution_role(role_name="sagemaker", aws_account=None, aws_region=None): """ Create sagemaker execution role to perform sagemaker task Args: role_name (string): name of the role to be created aws_account (string): aws account of the ECR repo aws_region (string): aws region where the repo is located """ session = boto3.Session() aws_account = aws_account or session.client("sts").get_caller_identity()["Account"] aws_region = aws_region or session.region_name assume_role_policy_document = json.dumps( { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Principal": { "Service": ["sagemaker.amazonaws.com", "robomaker.amazonaws.com"] }, "Action": "sts:AssumeRole", } ], } ) client = session.client("iam") try: client.get_role(RoleName=role_name) except client.exceptions.NoSuchEntityException: client.create_role( RoleName=role_name, AssumeRolePolicyDocument=str(assume_role_policy_document) ) print("Created new sagemaker execution role: %s" % role_name) client.attach_role_policy( PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess", RoleName=role_name ) return client.get_role(RoleName=role_name)["Role"]["Arn"] def wait_for_training_job_to_complete(job_name): sagemaker_client = boto3.client("sagemaker") sagemaker_client.get_waiter("training_job_completed_or_stopped").wait(TrainingJobName=job_name)