# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os

from mock import patch
import pytest

from sagemaker_inference import environment, parameters


@patch.dict(
    os.environ,
    {
        parameters.USER_PROGRAM_ENV: "main.py",
        parameters.MODEL_SERVER_TIMEOUT_ENV: "20",
        parameters.MODEL_SERVER_TIMEOUT_SECONDS_ENV: "30",
        parameters.MODEL_SERVER_WORKERS_ENV: "8",
        parameters.STARTUP_TIMEOUT_ENV: "50",
        parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html",
        parameters.BIND_TO_PORT_ENV: "1738",
        parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
        parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport",
        parameters.MAX_REQUEST_SIZE: "10",
    },
    clear=True,
)
def test_env():
    """Check that ``Environment`` exposes the values set in ``os.environ``.

    ``os.environ`` is replaced (``clear=True``) with exactly the variables
    listed in the decorator, so every assertion below is driven by that
    mapping alone.
    """
    settings = environment.Environment()

    # Path constants live on the module, not on the Environment instance.
    assert environment.base_dir.endswith("/opt/ml")
    assert environment.model_dir.endswith("/opt/ml/model")
    assert environment.code_dir.endswith("opt/ml/model/code")

    # "main.py" is expected to surface without its ".py" suffix.
    assert settings.module_name == "main"

    # Timeouts are expected as integers ...
    assert settings.model_server_timeout == 20
    assert settings.model_server_timeout_seconds == 30
    assert settings.startup_timeout == 50

    # ... while these settings are expected to stay the raw strings.
    assert settings.model_server_workers == "8"
    assert settings.default_accept == "text/html"
    assert settings.safe_port_range == "1111-2222"

    # Both ports are expected to carry the single BIND_TO_PORT_ENV value.
    assert settings.inference_http_port == "1738"
    assert settings.management_http_port == "1738"

    assert "-XX:-UseContainerSupport" in settings.vmargs

    # The configured "10" is expected as 10 * 1024 * 1024
    # (presumably megabytes converted to bytes -- confirm in Environment).
    assert settings.max_request_size == 10 * 1024 * 1024


@pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])
@patch.dict(os.environ, {}, clear=True)
def test_env_module_name(sagemaker_program):
    """Check that ``module_name`` is "program" with or without a ``.py`` suffix.

    ``patch.dict(os.environ, {}, clear=True)`` empties ``os.environ`` for the
    duration of the test and restores its original contents afterwards, so the
    variable set below needs no manual cleanup.

    Args:
        sagemaker_program: Value assigned to the user-program environment
            variable for this parametrized case.
    """
    os.environ[parameters.USER_PROGRAM_ENV] = sagemaker_program

    assert environment.Environment().module_name == "program"