# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import from sagemaker import image_uris from tests.unit.sagemaker.image_uris import expected_uris import pytest ACCOUNTS = { "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", "us-west-2": "763104351884", } TRAINIUM_REGIONS = ACCOUNTS.keys() TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch" def _expected_trainium_framework_uri( framework, version, region="us-west-2", inference_tool="neuron" ): return expected_uris.neuron_framework_uri( "{}-neuron".format(framework), fw_version=version, py_version="py38", account=ACCOUNTS[region], region=region, inference_tool=inference_tool, ) def _test_trainium_framework_uris(framework, version): for region in TRAINIUM_REGIONS: uri = image_uris.retrieve( framework, region, instance_type="ml.trn1.xlarge", version=version ) expected = _expected_trainium_framework_uri( "{}-training".format(framework), version, region=region, inference_tool="neuron" ) assert expected == uri def test_trainium_pytorch(pytorch_neuron_version): _test_trainium_framework_uris("pytorch", pytorch_neuron_version) def _test_trainium_unsupported_framework(framework, framework_version): for region in TRAINIUM_REGIONS: with pytest.raises(ValueError) as error: image_uris.retrieve( framework, region, version=framework_version, instance_type="ml.trn1.xlarge" ) expectedErr = ( f"Unsupported framework: {framework}. Supported framework(s) for Trainium instances: " f"{TRAINIUM_ALLOWED_FRAMEWORKS}." ) assert expectedErr in str(error) def test_trainium_unsupported_framework(): _test_trainium_unsupported_framework("autogluon", "0.6.1")