# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import inspect
import json
import os
import tarfile
from contextlib import contextmanager
from itertools import product

import pytest

from mock import Mock, patch

from sagemaker import fw_utils
from sagemaker.utils import name_from_image
from sagemaker.session_settings import SessionSettings
from sagemaker.instance_group import InstanceGroup

TIMESTAMP = "2017-10-10-14-14-15"


@contextmanager
def cd(path):
    old_dir = os.getcwd()
    os.chdir(path)
    yield
    os.chdir(old_dir)


@pytest.fixture()
def sagemaker_session():
    boto_mock = Mock(name="boto_session", region_name="us-west-2")
    session_mock = Mock(
        name="sagemaker_session",
        boto_session=boto_mock,
        s3_client=None,
        s3_resource=None,
        settings=SessionSettings(),
        default_bucket_prefix=None,
    )
    session_mock.default_bucket = Mock(name="default_bucket", return_value="my-bucket")
    session_mock.expand_role = Mock(name="expand_role", return_value="my-role")
    session_mock.sagemaker_client.describe_training_job = Mock(
        return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
    )
    return session_mock


def test_tar_and_upload_dir_s3(sagemaker_session):
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = "mnist.py"
    directory = "s3://m"
    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, directory
    )

    assert result == fw_utils.UploadedCode("s3://m", "mnist.py")


def test_tar_and_upload_dir_s3_with_script_dir(sagemaker_session):
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = "some/dir/mnist.py"
    directory = "s3://m"
    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, directory
    )

    assert result == fw_utils.UploadedCode("s3://m", "some/dir/mnist.py")


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = "mnist.py"
    kms_key = "kms-key"
    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key
    )

    assert result == fw_utils.UploadedCode(
        "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script
    )

    extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
    obj = sagemaker_session.resource("s3").Object("", "")
    obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_kms_enabled_by_default(utils, sagemaker_session):
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = "inference.py"
    result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script)

    assert result == fw_utils.UploadedCode(
        "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script
    )

    extra_args = {"ServerSideEncryption": "aws:kms"}
    obj = sagemaker_session.resource("s3").Object("", "")
    obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_without_kms_with_overridden_settings(utils, sagemaker_session):
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = "inference.py"
    settings = SessionSettings(encrypt_repacked_artifacts=False)
    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, settings=settings
    )

    assert result == fw_utils.UploadedCode(
        "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script
    )

    obj = sagemaker_session.resource("s3").Object("", "")
    obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)


def test_mp_config_partition_exists():
    mp_parameters = {}
    with pytest.raises(ValueError):
        fw_utils.validate_mp_config(mp_parameters)


@pytest.mark.parametrize(
    "pipeline, placement_strategy, optimize, trace_device",
    [
        ("simple", "spread", "speed", "cpu"),
        ("interleaved", "cluster", "memory", "gpu"),
        ("_only_forward", "spread", "speed", "gpu"),
    ],
)
def test_mp_config_string_names(pipeline, placement_strategy, optimize, trace_device):
    mp_parameters = {
        "partitions": 2,
        "pipeline": pipeline,
        "placement_strategy": placement_strategy,
        "optimize": optimize,
        "trace_device": trace_device,
        "active_microbatches": 8,
        "deterministic_server": True,
    }
    fw_utils.validate_mp_config(mp_parameters)


def test_mp_config_auto_partition_arg():
    mp_parameters = {}
    mp_parameters["partitions"] = 2
    mp_parameters["auto_partition"] = False
    with pytest.raises(ValueError):
        fw_utils.validate_mp_config(mp_parameters)

    mp_parameters["default_partition"] = 1
    fw_utils.validate_mp_config(mp_parameters)

    mp_parameters["default_partition"] = 4
    with pytest.raises(ValueError):
        fw_utils.validate_mp_config(mp_parameters)


def test_validate_source_dir_does_not_exits(sagemaker_session):
    script = "mnist.py"
    directory = " !@#$%^&*()path probably in not there.!@#$%^&*()"
    with pytest.raises(ValueError):
        fw_utils.validate_source_dir(script, directory)


def test_validate_source_dir_is_not_directory(sagemaker_session):
    script = "mnist.py"
    directory = inspect.getfile(inspect.currentframe())
    with pytest.raises(ValueError):
        fw_utils.validate_source_dir(script, directory)


def test_validate_source_dir_file_not_in_dir():
    script = " !@#$%^&*() .myscript. !@#$%^&*() "
    directory = "."
    with pytest.raises(ValueError):
        fw_utils.validate_source_dir(script, directory)


def test_parse_mp_parameters_input_dict():
    mp_parameters = {
        "partitions": 1,
        "tensor_parallel_degree": 2,
        "microbatches": 1,
        "optimize": "speed",
        "pipeline": "interleaved",
        "ddp": 1,
        "auto_partition": False,
        "default_partition": 0,
    }
    assert mp_parameters == fw_utils.parse_mp_parameters(mp_parameters)


def test_parse_mp_parameters_input_str_json():
    mp_parameters = {
        "partitions": 1,
        "tensor_parallel_degree": 2,
        "microbatches": 1,
        "optimize": "speed",
        "pipeline": "interleaved",
        "ddp": 1,
        "auto_partition": False,
        "default_partition": 0,
    }
    json_file_path = "./params.json"
    with open(json_file_path, "x") as fp:
        json.dump(mp_parameters, fp)
    assert mp_parameters == fw_utils.parse_mp_parameters(json_file_path)
    os.remove(json_file_path)


def test_parse_mp_parameters_input_not_exit():
    with pytest.raises(ValueError):
        fw_utils.parse_mp_parameters(" !@#$%^&*()path probably in not there.!@#$%^&*()")


def test_tar_and_upload_dir_not_s3(sagemaker_session):
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = os.path.basename(__file__)
    directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, directory
    )
    assert result == fw_utils.UploadedCode(
        "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script
    )


def file_tree(tmpdir, files=None, folders=None):
    files = files or []
    folders = folders or []
    for file in files:
        tmpdir.join(file).ensure(file=True)

    for folder in folders:
        tmpdir.join(folder).ensure(dir=True)

    return str(tmpdir)


def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir):
    source_dir = file_tree(tmpdir, ["train.py"])
    entrypoint = os.path.join(source_dir, "train.py")

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", entrypoint, None
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py"
    )

    assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir)


def test_tar_and_upload_dir_no_directory_only_entrypoint(sagemaker_session, tmpdir):
    source_dir = file_tree(tmpdir, ["train.py", "not_me.py"])
    entrypoint = os.path.join(source_dir, "train.py")

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", entrypoint, None
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py"
    )

    assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir)


def test_tar_and_upload_dir_no_directory_bare_filename(sagemaker_session, tmpdir):
    source_dir = file_tree(tmpdir, ["train.py"])
    entrypoint = "train.py"

    with patch("shutil.rmtree"):
        with cd(source_dir):
            result = fw_utils.tar_and_upload_dir(
                sagemaker_session, "bucket", "prefix", entrypoint, None
            )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py"
    )

    assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir)


def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir):
    file_tree(tmpdir, ["src-dir/train.py"])
    source_dir = os.path.join(str(tmpdir), "src-dir")

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", "train.py", source_dir
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py"
    )

    assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir)


def test_tar_and_upload_dir_with_subdirectory(sagemaker_session, tmpdir):
    file_tree(tmpdir, ["src-dir/sub/train.py"])
    source_dir = os.path.join(str(tmpdir), "src-dir")

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", "train.py", source_dir
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py"
    )

    assert {"/sub/train.py"} == list_source_dir_files(sagemaker_session, tmpdir)


def test_tar_and_upload_dir_with_directory_and_files(sagemaker_session, tmpdir):
    file_tree(tmpdir, ["src-dir/train.py", "src-dir/laucher", "src-dir/module/__init__.py"])
    source_dir = os.path.join(str(tmpdir), "src-dir")

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", "train.py", source_dir
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py"
    )

    assert {"/laucher", "/module/__init__.py", "/train.py"} == list_source_dir_files(
        sagemaker_session, tmpdir
    )


def test_tar_and_upload_dir_with_directories_and_files(sagemaker_session, tmpdir):
    file_tree(tmpdir, ["src-dir/a/b", "src-dir/a/b2", "src-dir/x/y", "src-dir/x/y2", "src-dir/z"])
    source_dir = os.path.join(str(tmpdir), "src-dir")

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", "a/b", source_dir
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="a/b"
    )

    assert {"/a/b", "/a/b2", "/x/y", "/x/y2", "/z"} == list_source_dir_files(
        sagemaker_session, tmpdir
    )


def test_tar_and_upload_dir_with_many_folders(sagemaker_session, tmpdir):
    file_tree(tmpdir, ["src-dir/a/b", "src-dir/a/b2", "common/x/y", "common/x/y2", "t/y/z"])
    source_dir = os.path.join(str(tmpdir), "src-dir")
    dependencies = [os.path.join(str(tmpdir), "common"), os.path.join(str(tmpdir), "t", "y", "z")]

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session, "bucket", "prefix", "pipeline.py", source_dir, dependencies
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="pipeline.py"
    )

    assert {"/a/b", "/a/b2", "/common/x/y", "/common/x/y2", "/z"} == list_source_dir_files(
        sagemaker_session, tmpdir
    )


def test_test_tar_and_upload_dir_with_subfolders(sagemaker_session, tmpdir):
    file_tree(tmpdir, ["a/b/c", "a/b/c2"])
    root = file_tree(tmpdir, ["x/y/z", "x/y/z2"])

    with patch("shutil.rmtree"):
        result = fw_utils.tar_and_upload_dir(
            sagemaker_session,
            "bucket",
            "prefix",
            "b/c",
            os.path.join(root, "a"),
            [os.path.join(root, "x")],
        )

    assert result == fw_utils.UploadedCode(
        s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="b/c"
    )

    assert {"/b/c", "/b/c2", "/x/y/z", "/x/y/z2"} == list_source_dir_files(
        sagemaker_session, tmpdir
    )


def list_source_dir_files(sagemaker_session, tmpdir):
    source_dir_tar = sagemaker_session.resource("s3").Object().upload_file.call_args[0][0]

    source_dir_files = list_tar_files("/opt/ml/code/", source_dir_tar, tmpdir)
    return source_dir_files


def list_tar_files(folder, tar_ball, tmpdir):
    startpath = str(tmpdir.ensure(folder, dir=True))

    with tarfile.open(name=tar_ball, mode="r:gz") as t:
        t.extractall(path=startpath)

    def walk():
        for root, dirs, files in os.walk(startpath):
            path = root.replace(startpath, "")
            for f in files:
                yield "%s/%s" % (path, f)

    result = set(walk())
    return result if result else {}


def test_framework_name_from_image_mxnet():
    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3"
    assert ("mxnet", "py3", "1.1-gpu-py3", None) == fw_utils.framework_name_from_image(image_uri)


def test_framework_name_from_image_mxnet_in_gov():
    image_uri = "123.dkr.ecr.region-name.c2s.ic.gov/sagemaker-mxnet:1.1-gpu-py3"
    assert ("mxnet", "py3", "1.1-gpu-py3", None) == fw_utils.framework_name_from_image(image_uri)


def test_framework_name_from_image_tf():
    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2"
    assert ("tensorflow", "py2", "1.6-cpu-py2", None) == fw_utils.framework_name_from_image(
        image_uri
    )


def test_framework_name_from_image_tf_scriptmode():
    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.12-cpu-py3"
    assert (
        "tensorflow",
        "py3",
        "1.12-cpu-py3",
        "scriptmode",
    ) == fw_utils.framework_name_from_image(image_uri)

    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13-cpu-py3"
    assert ("tensorflow", "py3", "1.13-cpu-py3", "training") == fw_utils.framework_name_from_image(
        image_uri
    )


def test_framework_name_from_image_rl():
    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3"
    assert ("mxnet", "py3", "toolkit1.1-gpu-py3", None) == fw_utils.framework_name_from_image(
        image_uri
    )


def test_framework_name_from_image_python_versions():
    image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.2-cpu-py37"
    assert ("tensorflow", "py37", "2.2-cpu-py37", "training") == fw_utils.framework_name_from_image(
        image_name
    )

    image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.15.2-cpu-py36"
    expected_result = ("tensorflow", "py36", "1.15.2-cpu-py36", "training")
    assert expected_result == fw_utils.framework_name_from_image(image_name)


def test_legacy_name_from_framework_image():
    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-gpu:2.5.6-gpu-py2"
    framework, py_ver, tag, _ = fw_utils.framework_name_from_image(image_uri)
    assert framework == "mxnet"
    assert py_ver == "py3"
    assert tag == "2.5.6-gpu-py2"


def test_legacy_name_from_wrong_framework():
    framework, py_ver, tag, _ = fw_utils.framework_name_from_image(
        "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1"
    )
    assert framework is None
    assert py_ver is None
    assert tag is None


def test_legacy_name_from_wrong_python():
    framework, py_ver, tag, _ = fw_utils.framework_name_from_image(
        "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1"
    )
    assert framework is None
    assert py_ver is None
    assert tag is None


def test_legacy_name_from_wrong_device():
    framework, py_ver, tag, _ = fw_utils.framework_name_from_image(
        "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1"
    )
    assert framework is None
    assert py_ver is None
    assert tag is None


def test_legacy_name_from_image_any_tag():
    image_uri = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:any-tag"
    framework, py_ver, tag, _ = fw_utils.framework_name_from_image(image_uri)
    assert framework == "tensorflow"
    assert py_ver == "py2"
    assert tag == "any-tag"


def test_framework_version_from_tag():
    tags = (
        "1.5rc-keras-cpu-py2",
        "1.5rc-keras-gpu-py2",
        "1.5rc-keras-cpu-py3",
        "1.5rc-keras-gpu-py36",
        "1.5rc-keras-gpu-py37",
    )

    for tag in tags:
        version = fw_utils.framework_version_from_tag(tag)
        assert "1.5rc-keras" == version


def test_framework_version_from_tag_other():
    version = fw_utils.framework_version_from_tag("weird-tag-py2")
    assert version is None


def test_xgboost_version_from_tag():
    tags = (
        "1.5-1-cpu-py3",
        "1.5-1",
    )

    for tag in tags:
        version = fw_utils.framework_version_from_tag(tag)
        assert "1.5-1" == version


def test_framework_name_from_xgboost_image_short_tag():
    ecr_uri = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost"
    image_tag = "1.5-1"
    image_uri = f"{ecr_uri}:{image_tag}"
    expected_result = ("xgboost", "py3", "1.5-1", None)
    assert expected_result == fw_utils.framework_name_from_image(image_uri)


def test_framework_name_from_xgboost_image_long_tag():
    ecr_uri = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost"
    image_tag = "1.5-1-cpu-py3"
    image_uri = f"{ecr_uri}:{image_tag}"
    expected_result = ("xgboost", "py3", "1.5-1-cpu-py3", None)
    assert expected_result == fw_utils.framework_name_from_image(image_uri)


def test_model_code_key_prefix_with_all_values_present():
    key_prefix = fw_utils.model_code_key_prefix("prefix", "model_name", "image_uri")
    assert key_prefix == "prefix/model_name"


def test_model_code_key_prefix_with_no_prefix_and_all_other_values_present():
    key_prefix = fw_utils.model_code_key_prefix(None, "model_name", "image_uri")
    assert key_prefix == "model_name"


@patch("time.strftime", return_value=TIMESTAMP)
def test_model_code_key_prefix_with_only_image_present(time):
    key_prefix = fw_utils.model_code_key_prefix(None, None, "image_uri")
    assert key_prefix == name_from_image("image_uri")


@patch("time.strftime", return_value=TIMESTAMP)
def test_model_code_key_prefix_and_image_present(time):
    key_prefix = fw_utils.model_code_key_prefix("prefix", None, "image_uri")
    assert key_prefix == "prefix/" + name_from_image("image_uri")


def test_model_code_key_prefix_with_prefix_present_and_others_none_fail():
    with pytest.raises(TypeError) as error:
        fw_utils.model_code_key_prefix("prefix", None, None)
    assert "expected string" in str(error.value)


def test_model_code_key_prefix_with_all_none_fail():
    with pytest.raises(TypeError) as error:
        fw_utils.model_code_key_prefix(None, None, None)
    assert "expected string" in str(error.value)


def test_region_supports_debugger_feature_returns_true_for_supported_regions():
    assert fw_utils._region_supports_debugger("us-west-2") is True
    assert fw_utils._region_supports_debugger("us-east-2") is True


def test_region_supports_debugger_feature_returns_false_for_unsupported_regions():
    assert fw_utils._region_supports_debugger("us-iso-east-1") is False
    assert fw_utils._region_supports_debugger("us-isob-east-1") is False
    assert fw_utils._region_supports_debugger("ap-southeast-3") is False
    assert fw_utils._region_supports_debugger("ap-southeast-4") is False
    assert fw_utils._region_supports_debugger("eu-south-2") is False
    assert fw_utils._region_supports_debugger("me-central-1") is False
    assert fw_utils._region_supports_debugger("ap-south-2") is False
    assert fw_utils._region_supports_debugger("eu-central-2") is False
    assert fw_utils._region_supports_debugger("us-gov-east-1") is False


def test_warn_if_parameter_server_with_multi_gpu(caplog):
    instance_type = "ml.p2.8xlarge"
    distribution = {"parameter_server": {"enabled": True}}

    fw_utils.warn_if_parameter_server_with_multi_gpu(
        training_instance_type=instance_type, distribution=distribution
    )
    assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text


def test_warn_if_parameter_server_with_local_multi_gpu(caplog):
    instance_type = "local_gpu"
    distribution = {"parameter_server": {"enabled": True}}

    fw_utils.warn_if_parameter_server_with_multi_gpu(
        training_instance_type=instance_type, distribution=distribution
    )
    assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text


def test_validate_version_or_image_args_not_raises():
    good_args = [("1.0", "py3", None), (None, "py3", "my:uri"), ("1.0", None, "my:uri")]
    for framework_version, py_version, image_uri in good_args:
        fw_utils.validate_version_or_image_args(framework_version, py_version, image_uri)


def test_validate_version_or_image_args_raises():
    bad_args = [(None, None, None), (None, "py3", None), ("1.0", None, None)]
    for framework_version, py_version, image_uri in bad_args:
        with pytest.raises(ValueError):
            fw_utils.validate_version_or_image_args(framework_version, py_version, image_uri)


def test_validate_distribution_not_raises():
    train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1)
    other_group = InstanceGroup("other_group", "ml.p3.16xlarge", 1)
    instance_groups = [train_group, other_group]

    smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
    smdataparallel_enabled_custom_mpi = {
        "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
    }
    smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
    mpi_enabled = {"mpi": {"enabled": True, "processes_per_host": 2}}
    mpi_disabled = {"mpi": {"enabled": False}}

    instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)

    good_args_normal = [
        smdataparallel_enabled,
        smdataparallel_enabled_custom_mpi,
        smdataparallel_disabled,
        mpi_enabled,
        mpi_disabled,
    ]

    frameworks = ["tensorflow", "pytorch"]

    for framework, instance_type in product(frameworks, instance_types):
        for distribution in good_args_normal:
            fw_utils.validate_distribution(
                distribution,
                None,  # instance_groups
                framework,
                None,  # framework_version
                None,  # py_version
                "custom-container",
                {"instance_type": instance_type, "entry_point": "train.py"},  # kwargs
            )

    for framework in frameworks:
        good_args_hc = [
            {
                "smdistributed": {"dataparallel": {"enabled": True}},
                "instance_groups": [train_group],
            },  # smdataparallel_enabled_hc
            {
                "mpi": {"enabled": True, "processes_per_host": 2},
                "instance_groups": [train_group],
            },  # mpi_enabled_hc
            {
                "smdistributed": {
                    "dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"},
                },
                "instance_groups": [train_group],
            },  # smdataparallel_enabled_custom_mpi_hc
        ]
        for distribution in good_args_hc:
            fw_utils.validate_distribution(
                distribution,
                instance_groups,  # instance_groups
                framework,
                None,  # framework_version
                None,  # py_version
                "custom-container",
                {"entry_point": "train.py"},  # kwargs
            )


def test_validate_distribution_raises():
    train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1)
    other_group = InstanceGroup("other_group", "ml.p3.16xlarge", 1)
    dummy_group = InstanceGroup("dummy_group", "ml.p3.16xlarge", 1)
    instance_groups = [train_group, other_group, dummy_group]

    mpi_enabled_hc = {
        "mpi": {"enabled": True, "processes_per_host": 2},
        "instance_groups": [train_group, other_group],
    }
    smdataparallel_enabled_hc = {
        "smdistributed": {"dataparallel": {"enabled": True}},
        "instance_groups": [],
    }

    instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)

    bad_args_normal = [
        {"smdistributed": "dummy"},
        {"smdistributed": {"dummy"}},
        {"smdistributed": {"dummy": "val"}},
        {"smdistributed": {"dummy": {"enabled": True}}},
    ]
    bad_args_hc = [mpi_enabled_hc, smdataparallel_enabled_hc]
    frameworks = ["tensorflow", "pytorch"]

    for framework, instance_type in product(frameworks, instance_types):
        for distribution in bad_args_normal:
            with pytest.raises(ValueError):
                fw_utils.validate_distribution(
                    distribution,
                    None,  # instance_groups
                    framework,
                    None,  # framework_version
                    None,  # py_version
                    "custom-container",
                    {"instance_type": instance_type, "entry_point": "train.py"},  # kwargs
                )

    for framework in frameworks:
        for distribution in bad_args_hc:
            with pytest.raises(ValueError):
                fw_utils.validate_distribution(
                    distribution,
                    instance_groups,  # instance_groups
                    framework,
                    None,  # framework_version
                    None,  # py_version
                    "custom-container",
                    {},  # kwargs
                )


def test_validate_smdistributed_not_raises():
    smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
    smdataparallel_enabled_custom_mpi = {
        "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
    }
    smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
    instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)

    good_args = [
        (smdataparallel_enabled, "custom-container"),
        (smdataparallel_enabled_custom_mpi, "custom-container"),
        (smdataparallel_disabled, "custom-container"),
    ]
    frameworks = ["tensorflow", "pytorch"]

    for framework, instance_type in product(frameworks, instance_types):
        for distribution, image_uri in good_args:
            fw_utils.validate_smdistributed(
                instance_type=instance_type,
                framework_name=framework,
                framework_version=None,
                py_version=None,
                distribution=distribution,
                image_uri=image_uri,
            )


def test_validate_smdistributed_raises():
    bad_args = [
        {"smdistributed": "dummy"},
        {"smdistributed": {"dummy"}},
        {"smdistributed": {"dummy": "val"}},
        {"smdistributed": {"dummy": {"enabled": True}}},
    ]
    instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)
    frameworks = ["tensorflow", "pytorch"]
    for framework, distribution, instance_type in product(frameworks, bad_args, instance_types):
        with pytest.raises(ValueError):
            fw_utils.validate_smdistributed(
                instance_type=instance_type,
                framework_name=framework,
                framework_version=None,
                py_version=None,
                distribution=distribution,
                image_uri="custom-container",
            )


def test_validate_smdataparallel_args_raises():
    # TODO: add validation for dataparallel in mxnet
    smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}

    # Cases {PT|TF2}
    # 1. None instance type
    # 2. incorrect instance type
    # 3. incorrect python version
    # 4. incorrect framework version

    bad_args = [
        (None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
        ("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
        ("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled),
        (None, "pytorch", "1.6.0", "py3", smdataparallel_enabled),
        ("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
        ("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled),
    ]
    for instance_type, framework_name, framework_version, py_version, distribution in bad_args:
        with pytest.raises(ValueError):
            fw_utils._validate_smdataparallel_args(
                instance_type, framework_name, framework_version, py_version, distribution
            )


def test_validate_smdataparallel_args_not_raises():
    smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
    smdataparallel_enabled_custom_mpi = {
        "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
    }
    smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}

    # Cases {PT|TF2}
    # 1. SM Distributed dataparallel disabled
    # 2. SM Distributed dataparallel enabled with supported args

    good_args = [
        (None, None, None, None, smdataparallel_disabled),
        ("ml.p3.16xlarge", "tensorflow", "2.3.1", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.3.2", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.3", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.4.3", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.4", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.5.0", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.5", "py37", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.6.0", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.6.2", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.6.3", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.6", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.7.1", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.7", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.8", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.9", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.10.1", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.10", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.11.0", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.11", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.6", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.7", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.8.1", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.9", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.10.0", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.10", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.11", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.4.3", "py37", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.6.0", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.6.2", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.6.3", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.7.1", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.10.1", "py39", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "tensorflow", "2.11.0", "py39", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
        ("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
    ]
    for instance_type, framework_name, framework_version, py_version, distribution in good_args:
        fw_utils._validate_smdataparallel_args(
            instance_type, framework_name, framework_version, py_version, distribution
        )


def test_validate_pytorchddp_not_raises():
    # Case 1: Framework is not PyTorch
    fw_utils.validate_pytorch_distribution(
        distribution=None,
        framework_name="tensorflow",
        framework_version="2.9.1",
        py_version="py3",
        image_uri="custom-container",
    )
    # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
    pytorchddp_disabled = {"pytorchddp": {"enabled": False}}
    fw_utils.validate_pytorch_distribution(
        distribution=pytorchddp_disabled,
        framework_name="pytorch",
        framework_version="1.10",
        py_version="py3",
        image_uri="custom-container",
    )
    # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
    pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
    pytorchddp_supported_fw_versions = [
        "1.10",
        "1.10.0",
        "1.10.2",
        "1.11",
        "1.11.0",
        "1.12",
        "1.12.0",
        "1.12.1",
    ]
    for framework_version in pytorchddp_supported_fw_versions:
        fw_utils.validate_pytorch_distribution(
            distribution=pytorchddp_enabled,
            framework_name="pytorch",
            framework_version=framework_version,
            py_version="py3",
            image_uri="custom-container",
        )


def test_validate_pytorchddp_raises():
    pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
    # Case 1: Unsupported framework version
    with pytest.raises(ValueError):
        fw_utils.validate_pytorch_distribution(
            distribution=pytorchddp_enabled,
            framework_name="pytorch",
            framework_version="1.8",
            py_version="py3",
            image_uri=None,
        )

    # Case 2: Unsupported Py version
    with pytest.raises(ValueError):
        fw_utils.validate_pytorch_distribution(
            distribution=pytorchddp_enabled,
            framework_name="pytorch",
            framework_version="1.10",
            py_version="py2",
            image_uri=None,
        )


def test_validate_torch_distributed_not_raises():
    # Case 1: Framework is PyTorch, but torch_distributed is not enabled
    torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
    fw_utils.validate_torch_distributed_distribution(
        instance_type="ml.trn1.2xlarge",
        distribution=torch_distributed_disabled,
        framework_version="1.11.0",
        py_version="py3",
        image_uri=None,
        entry_point="train.py",
    )
    # Case 2: Distribution is torch_distributed enabled, supported framework and py versions
    torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
    torch_distributed_supported_fw_versions = [
        "1.11.0",
    ]
    for framework_version in torch_distributed_supported_fw_versions:
        fw_utils.validate_torch_distributed_distribution(
            instance_type="ml.trn1.2xlarge",
            distribution=torch_distributed_enabled,
            framework_version=framework_version,
            py_version="py3",
            image_uri=None,
            entry_point="train.py",
        )

    # Case 3: Distribution is torch_distributed enabled, supported framework and instances
    torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
    torch_distributed_gpu_supported_fw_versions = [
        "1.13.1",
        "2.0.0",
    ]
    for framework_version in torch_distributed_gpu_supported_fw_versions:
        fw_utils.validate_torch_distributed_distribution(
            instance_type="ml.p3.8xlarge",
            distribution=torch_distributed_enabled,
            framework_version=framework_version,
            py_version="py3",
            image_uri=None,
            entry_point="train.py",
        )


def test_validate_torch_distributed_raises():
    torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
    # Case 1: Unsupported framework version
    with pytest.raises(ValueError):
        fw_utils.validate_torch_distributed_distribution(
            instance_type="ml.trn1.2xlarge",
            distribution=torch_distributed_enabled,
            framework_version="1.10.0",
            py_version="py3",
            image_uri=None,
            entry_point="train.py",
        )

    # Case 2: Unsupported Py version
    with pytest.raises(ValueError):
        fw_utils.validate_torch_distributed_distribution(
            instance_type="ml.trn1.2xlarge",
            distribution=torch_distributed_enabled,
            framework_version="1.11.0",
            py_version="py2",
            image_uri=None,
            entry_point="train.py",
        )

    # Case 3: Unsupported Entry point type
    with pytest.raises(ValueError):
        fw_utils.validate_torch_distributed_distribution(
            instance_type="ml.trn1.2xlarge",
            distribution=torch_distributed_enabled,
            framework_version="1.11.0",
            py_version="py3",
            image_uri=None,
            entry_point="train.sh",
        )

    # Case 4: Unsupported framework version for gpu instances
    with pytest.raises(ValueError):
        fw_utils.validate_torch_distributed_distribution(
            instance_type="ml.p3.8xlarge",
            distribution=torch_distributed_enabled,
            framework_version="1.11.0",
            py_version="py3",
            image_uri=None,
            entry_point="train.py",
        )


def test_validate_unsupported_distributions_trainium_raises():
    with pytest.raises(ValueError):
        mpi_enabled = {"mpi": {"enabled": True}}
        fw_utils.validate_distribution_for_instance_type(
            distribution=mpi_enabled,
            instance_type="ml.trn1.2xlarge",
        )

    with pytest.raises(ValueError):
        mpi_enabled = {"mpi": {"enabled": True}}
        fw_utils.validate_distribution_for_instance_type(
            distribution=mpi_enabled,
            instance_type="ml.trn1.32xlarge",
        )

    with pytest.raises(ValueError):
        pytorch_ddp_enabled = {"pytorch_ddp": {"enabled": True}}
        fw_utils.validate_distribution_for_instance_type(
            distribution=pytorch_ddp_enabled,
            instance_type="ml.trn1.32xlarge",
        )

    with pytest.raises(ValueError):
        smdataparallel_enabled = {"smdataparallel": {"enabled": True}}
        fw_utils.validate_distribution_for_instance_type(
            distribution=smdataparallel_enabled,
            instance_type="ml.trn1.32xlarge",
        )


def test_instance_type_supports_profiler():
    assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
    assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
    assert fw_utils._instance_type_supports_profiler("local") is False


def test_is_gpu_instance():
    gpu_instance_types = [
        "ml.p3.2xlarge",
        "ml.p3.8xlarge",
        "ml.p3.16xlarge",
        "ml.p3dn.24xlarge",
        "ml.p4d.24xlarge",
        "ml.p4de.24xlarge",
        "ml.g4dn.xlarge",
        "ml.g5.xlarge",
        "ml.g5.48xlarge",
        "local_gpu",
    ]
    non_gpu_instance_types = [
        "ml.t3.xlarge",
        "ml.m5.8xlarge",
        "ml.m5d.16xlarge",
        "ml.c5.9xlarge",
        "ml.r5.8xlarge",
    ]
    for gpu_type in gpu_instance_types:
        assert fw_utils._is_gpu_instance(gpu_type) is True
    for non_gpu_type in non_gpu_instance_types:
        assert fw_utils._is_gpu_instance(non_gpu_type) is False


def test_is_trainium_instance():
    trainium_instance_types = [
        "ml.trn1.2xlarge",
        "ml.trn1.32xlarge",
    ]
    non_trainum_instance_types = [
        "ml.t3.xlarge",
        "ml.m5.8xlarge",
        "ml.m5d.16xlarge",
        "ml.c5.9xlarge",
        "ml.r5.8xlarge",
        "ml.p3.2xlarge",
        "ml.p3.8xlarge",
        "ml.p3.16xlarge",
        "ml.p3dn.24xlarge",
        "ml.p4d.24xlarge",
        "ml.p4de.24xlarge",
        "ml.g4dn.xlarge",
        "ml.g5.xlarge",
        "ml.g5.48xlarge",
        "local_gpu",
    ]
    for tr_type in trainium_instance_types:
        assert fw_utils._is_trainium_instance(tr_type) is True
    for non_tr_type in non_trainum_instance_types:
        assert fw_utils._is_trainium_instance(non_tr_type) is False