# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from mock import MagicMock, Mock, patch
from packaging.version import Version

from sagemaker import image_uris
from sagemaker.tensorflow import TensorFlow

BUCKET_NAME = "mybucket"
LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
REGION = "us-west-2"
ROLE = "Dummy"


@pytest.fixture()
def sagemaker_session():
    """Return a Mock SageMaker session preconfigured for ``TensorFlow.attach``."""
    boto_mock = Mock(name="boto_session", region_name=REGION)
    session = Mock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=REGION,
        config=None,
        local_mode=False,
        s3_resource=None,
        s3_client=None,
        default_bucket_prefix=None,
    )
    session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
    session.expand_role = Mock(name="expand_role", return_value=ROLE)
    describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
    session.sagemaker_client.describe_training_job = Mock(return_value=describe)
    session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
    # For tests that don't verify config file injection, operate with an empty config.
    session.sagemaker_config = {}
    return session


def _training_job_description(training_image, include_job_name=True):
    """Build a fake DescribeTrainingJob response for the completed job "neo".

    Args:
        training_image (str): Value for ``AlgorithmSpecification.TrainingImage``.
        include_job_name (bool): Whether to add the ``sagemaker_job_name``
            hyperparameter (the wrong-framework test omits it).

    Returns:
        dict: A freshly built response, so no state is shared between tests.
    """
    # Hyperparameter values carry embedded quotes; the source_dir/entry_point
    # assertions below expect attach() to return them unquoted.
    hyperparameters = {
        "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
        "sagemaker_program": '"iris-dnn-classifier.py"',
        "sagemaker_container_log_level": '"logging.INFO"',
    }
    if include_job_name:
        hyperparameters["sagemaker_job_name"] = '"neo"'

    return {
        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
        "HyperParameters": hyperparameters,
        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
        "ResourceConfig": {
            "VolumeSizeInGB": 30,
            "InstanceCount": 1,
            "InstanceType": "ml.c4.xlarge",
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
        "TrainingJobName": "neo",
        "TrainingJobStatus": "Completed",
        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
    }


def _mock_describe_training_job(sagemaker_session, description):
    """Make the session's ``describe_training_job`` return *description*."""
    sagemaker_session.sagemaker_client.describe_training_job = Mock(
        name="describe_training_job", return_value=description
    )


@patch("sagemaker.utils.create_tar_file", MagicMock())
def test_attach(sagemaker_session, tensorflow_training_version, tensorflow_training_py_version):
    """Attaching to a job run on an official TF training image restores its settings."""
    if Version(tensorflow_training_version) > Version("1.12"):
        pytest.skip("framework_name_from_image doesn't infer info from DLC image URIs.")

    training_image = image_uris.retrieve(
        "tensorflow",
        region=REGION,
        version=tensorflow_training_version,
        py_version=tensorflow_training_py_version,
        instance_type="ml.c4.xlarge",
        image_scope="training",
    )
    _mock_describe_training_job(sagemaker_session, _training_job_description(training_image))

    estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session)

    assert estimator.latest_training_job.job_name == "neo"
    assert estimator.py_version == tensorflow_training_py_version
    assert estimator.framework_version == tensorflow_training_version
    assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"
    assert estimator.instance_count == 1
    assert estimator.instance_type == "ml.c4.xlarge"
    assert estimator.max_run == 24 * 60 * 60
    assert estimator.input_mode == "File"
    assert estimator.base_job_name == "neo"
    assert estimator.output_path == "s3://place/output/neo"
    assert estimator.output_kms_key == ""
    assert estimator.hyperparameters() is not None
    assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
    assert estimator.entry_point == "iris-dnn-classifier.py"
    assert estimator.training_image_uri() == training_image


@patch("sagemaker.utils.create_tar_file", MagicMock())
def test_attach_old_container(sagemaker_session):
    """Attaching to a job that used a legacy ``sagemaker-tensorflow`` image works.

    The framework/py versions asserted below pin how the legacy image tag is mapped.
    """
    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0"
    _mock_describe_training_job(sagemaker_session, _training_job_description(training_image))

    estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session)

    assert estimator.latest_training_job.job_name == "neo"
    assert estimator.py_version == "py2"
    assert estimator.framework_version == "1.4"
    assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"
    assert estimator.instance_count == 1
    assert estimator.instance_type == "ml.c4.xlarge"
    assert estimator.max_run == 24 * 60 * 60
    assert estimator.input_mode == "File"
    assert estimator.base_job_name == "neo"
    assert estimator.output_path == "s3://place/output/neo"
    assert estimator.output_kms_key == ""
    assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
    assert estimator.entry_point == "iris-dnn-classifier.py"


def test_attach_wrong_framework(sagemaker_session):
    """Attaching TensorFlow to a job that ran an MXNet image raises ValueError."""
    description = _training_job_description(
        "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0",
        include_job_name=False,
    )
    _mock_describe_training_job(sagemaker_session, description)

    with pytest.raises(ValueError) as error:
        TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
    # Check the raised exception's message, not the ExceptionInfo wrapper's repr.
    assert "didn't use image for requested framework" in str(error.value)


def test_attach_custom_image(sagemaker_session):
    """A custom (non-SageMaker) image URI is preserved verbatim on attach."""
    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/tensorflow_with_custom_binary:1.0"
    _mock_describe_training_job(sagemaker_session, _training_job_description(training_image))

    estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session)

    assert estimator.image_uri == training_image
    assert estimator.training_image_uri() == training_image