# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import numpy as np
import pytest
import xgboost as xgb
from mock import MagicMock, patch
from sagemaker_containers.beta.framework import content_types, encoders, errors

from sagemaker_algorithm_toolkit.exceptions import UserError
from sagemaker_xgboost_container import serving
from sagemaker_xgboost_container.constants import sm_env_constants

TEST_CONFIG_FILE = "test_dir"


@pytest.fixture(autouse=True)
def mock_set_mms_config_file(monkeypatch):
    monkeypatch.setenv("XGBOOST_MMS_CONFIG", TEST_CONFIG_FILE)


@pytest.fixture(scope="module", name="np_array")
def fixture_np_array():
    return np.ones((2, 2))


class FakeEstimator:
    def __init__(self):
        pass

    @staticmethod
    def predict(input):
        return


@pytest.mark.parametrize("csv_array", ("42,6,9", "42.0,6.0,9.0"))
def test_input_fn_dmatrix(csv_array):
    deserialized_csv_array = serving.default_input_fn(csv_array, content_types.CSV)
    assert type(deserialized_csv_array) is xgb.DMatrix


def test_input_fn_bad_content_type():
    with pytest.raises(errors.UnsupportedFormatError):
        serving.default_input_fn("", "application/not_supported")


def test_default_model_fn():
    with pytest.raises(NotImplementedError):
        serving.default_model_fn("model_dir")


def test_predict_fn(np_array):
    mock_estimator = FakeEstimator()
    with patch.object(mock_estimator, "predict") as mock:
        serving.default_predict_fn(np_array, mock_estimator)
    mock.assert_called_once()


def test_output_fn_json(np_array):
    response = serving.default_output_fn(np_array, content_types.JSON)

    assert response.get_data(as_text=True) == encoders.array_to_json(np_array.tolist())
    assert response.content_type == content_types.JSON


def test_output_fn_csv(np_array):
    response = serving.default_output_fn(np_array, content_types.CSV)

    assert response.get_data(as_text=True) == "1.0,1.0\n1.0,1.0\n"
    # TODO This is a workaround to get the test passsing.
    # Not sure if it is related to executing tests on Mac in specific virtual environment,
    # but the content type in response is: 'text/csv; charset=utf-8' instead of the expected: text/csv
    assert content_types.CSV in response.content_type


def test_output_fn_npz(np_array):
    response = serving.default_output_fn(np_array, content_types.NPY)

    assert response.get_data() == encoders.array_to_npy(np_array)
    assert response.content_type == content_types.NPY


def test_input_fn_bad_accept():
    with pytest.raises(errors.UnsupportedFormatError):
        serving.default_output_fn("", "application/not_supported")


@patch("sagemaker_xgboost_container.serving.server")
def test_serving_entrypoint_start_gunicorn(mock_server):
    mock_server.start = MagicMock()
    serving.serving_entrypoint()
    mock_server.start.assert_called_once()


@patch("sagemaker_xgboost_container.serving.server")
@patch("sagemaker_xgboost_container.serving.set_default_serving_env_if_unspecified")
def test_serving_entrypoint_set_default_env_positive(mock_set_default_serving_env_if_unspecified, mock_server):
    serving.serving_entrypoint()
    mock_set_default_serving_env_if_unspecified.assert_called_once()
    assert os.getenv("OMP_NUM_THREADS") == sm_env_constants.ONE_THREAD_PER_PROCESS


@patch("sagemaker_xgboost_container.serving.server")
@patch("sagemaker_xgboost_container.serving.set_default_serving_env_if_unspecified")
def test_serving_entrypoint_set_default_env_negative(mock_set_default_serving_env_if_unspecified, mock_server):
    with patch.dict(os.environ, {"OMP_NUM_THREADS": "USER_SPECIFIED_VALUE"}, clear=True):
        serving.serving_entrypoint()
        mock_set_default_serving_env_if_unspecified.assert_called_once()
        assert os.getenv("OMP_NUM_THREADS") == "USER_SPECIFIED_VALUE"


@patch.dict(
    os.environ,
    {
        "SAGEMAKER_MULTI_MODEL": "True",
    },
)
@patch("sagemaker_xgboost_container.serving.start_mxnet_model_server")
def test_serving_entrypoint_start_mms(mock_start_mxnet_model_server):
    serving.serving_entrypoint()
    mock_start_mxnet_model_server.assert_called_once()


@patch("sagemaker_xgboost_container.serving.transformer")
def test_user_module_transformer_with_transform_and_other_fn(mock_transformer):
    mock_module = MagicMock(spec=["model_fn", "transform_fn", "input_fn"])
    with pytest.raises(UserError):
        serving._user_module_transformer(mock_module)


@patch("sagemaker_xgboost_container.serving.transformer")
def test_user_module_transformer_with_transform_and_no_other_fn(mock_transformer):
    mock_module = MagicMock(spec=["model_fn", "transform_fn"])
    serving._user_module_transformer(mock_module)
    mock_transformer.Transformer.assert_called_once_with(
        model_fn=mock_module.model_fn, transform_fn=mock_module.transform_fn
    )


@patch("sagemaker_xgboost_container.serving.transformer")
def test_user_module_transformer_with_model_fn_only(mock_transformer):
    mock_module = MagicMock(spec=["model_fn"])
    serving._user_module_transformer(mock_module)
    mock_transformer.Transformer.assert_called_once_with(
        model_fn=mock_module.model_fn,
        input_fn=serving.default_input_fn,
        predict_fn=serving.default_predict_fn,
        output_fn=serving.default_output_fn,
    )


@patch("sagemaker_xgboost_container.serving.transformer")
def test_user_module_transformer_with_input_fn(mock_transformer):
    mock_module = MagicMock(spec=["model_fn", "input_fn"])
    serving._user_module_transformer(mock_module)
    mock_transformer.Transformer.assert_called_once_with(
        model_fn=mock_module.model_fn,
        input_fn=mock_module.input_fn,
        predict_fn=serving.default_predict_fn,
        output_fn=serving.default_output_fn,
    )


@patch("sagemaker_xgboost_container.serving.transformer")
def test_user_module_transformer_with_predict_fn(mock_transformer):
    mock_module = MagicMock(spec=["model_fn", "predict_fn"])
    serving._user_module_transformer(mock_module)
    mock_transformer.Transformer.assert_called_once_with(
        model_fn=mock_module.model_fn,
        input_fn=serving.default_input_fn,
        predict_fn=mock_module.predict_fn,
        output_fn=serving.default_output_fn,
    )


@patch("sagemaker_xgboost_container.serving.transformer")
def test_user_module_transformer_with_output_fn(mock_transformer):
    mock_module = MagicMock(spec=["model_fn", "output_fn"])
    serving._user_module_transformer(mock_module)
    mock_transformer.Transformer.assert_called_once_with(
        model_fn=mock_module.model_fn,
        input_fn=serving.default_input_fn,
        predict_fn=serving.default_predict_fn,
        output_fn=mock_module.output_fn,
    )