# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import import json import os import boto3 import pytest from botocore.config import Config from sagemaker import Session, utils DEFAULT_REGION = "us-west-2" CUSTOM_BUCKET_NAME_PREFIX = "smjsindustry-custom-bucket" NO_M4_REGIONS = [ "eu-west-3", "eu-north-1", "ap-east-1", "ap-northeast-1", # it has m4.xl, but not enough in all AZs "sa-east-1", "me-south-1", ] def pytest_addoption(parser): parser.addoption("--sagemaker-client-config", action="store", default=None) parser.addoption("--sagemaker-runtime-config", action="store", default=None) parser.addoption("--boto-config", action="store", default=None) def pytest_configure(config): bc = config.getoption("--boto-config") parsed = json.loads(bc) if bc else {} region = parsed.get("region_name", boto3.session.Session().region_name) if region: os.environ["TEST_AWS_REGION_NAME"] = region @pytest.fixture(scope="session") def sagemaker_client_config(request): config = request.config.getoption("--sagemaker-client-config") return json.loads(config) if config else dict() @pytest.fixture(scope="session") def sagemaker_runtime_config(request): config = request.config.getoption("--sagemaker-runtime-config") return json.loads(config) if config else None @pytest.fixture(scope="session") def boto_session(request): config = request.config.getoption("--boto-config") if config: return boto3.Session(**json.loads(config)) else: return boto3.Session(region_name=DEFAULT_REGION) @pytest.fixture(scope="session") def account(boto_session): return boto_session.client("sts").get_caller_identity()["Account"] @pytest.fixture(scope="session") def region(boto_session): return boto_session.region_name @pytest.fixture(scope="session") def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_session): sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) sagemaker_client = ( boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None ) runtime_client = ( boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) if sagemaker_runtime_config else None ) return Session( boto_session=boto_session, sagemaker_client=sagemaker_client, sagemaker_runtime_client=runtime_client, ) @pytest.fixture(scope="module") def custom_bucket_name(boto_session): region = boto_session.region_name account = boto_session.client( "sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region) ).get_caller_identity()["Account"] return "{}-{}-{}".format(CUSTOM_BUCKET_NAME_PREFIX, region, account) @pytest.fixture(scope="session") def cpu_instance_type(sagemaker_session, request): region = sagemaker_session.boto_session.region_name if region in NO_M4_REGIONS: return "ml.m5.xlarge" else: return "ml.m4.xlarge"