# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from mock import patch
import numpy as np
import pytest

from sagemaker_inference import (content_types, encoder, errors)
from sklearn.base import BaseEstimator

from sagemaker_sklearn_container.handler_service import HandlerService


# Single default inference handler instance shared by every test in this module.
handler = HandlerService().DefaultSKLearnUserModuleInferenceHandler()


@pytest.fixture(scope='module', name='np_array')
def fixture_np_array():
    """Provide a 2x2 array of ones, built once per module, under the name ``np_array``."""
    ones_2x2 = np.ones(shape=(2, 2))
    return ones_2x2


class FakeEstimator(BaseEstimator):
    """Minimal ``BaseEstimator`` subclass whose ``predict`` does nothing.

    It holds no state; ``predict`` exists only so that tests have an
    attribute to patch on an estimator-like object.
    """

    def __init__(self):
        # No parameters and no state to initialise.
        return

    @staticmethod
    def predict(input):
        """Accept a single argument, ignore it, and return ``None``."""
        return None

@pytest.mark.parametrize(
    'json_data, expected',
    [
        # Integer and float JSON numbers, each compared against a single-row array.
        ('[42, 6, 9]', np.array([[42, 6, 9]])),
        ('[42.0, 6.0, 9.0]', np.array([[42., 6., 9.]])),
        # JSON strings holding digits; the expected arrays are built as float32.
        ('["42", "6", "9"]', np.array([['42', '6', '9']], dtype=np.float32)),
        (u'["42", "6", "9"]', np.array([[u'42', u'6', u'9']], dtype=np.float32)),
    ],
)
def test_input_fn_json(json_data, expected):
    """Check that ``default_input_fn`` decodes a JSON payload into ``expected``."""
    decoded = handler.default_input_fn(json_data, content_types.JSON)
    np.testing.assert_equal(decoded, expected)


@pytest.mark.parametrize(
    'csv_data, expected',
    [
        ('42\n6\n9\n', np.array([[42, 6, 9]], dtype=np.float32)),
        ('42.0\n6.0\n9.0\n', np.array([[42., 6., 9.]], dtype=np.float32)),
    ],
)
def test_input_fn_csv(csv_data, expected):
    """Check that ``default_input_fn`` decodes a CSV payload into ``expected``.

    Each payload puts one value per line; the expected result is a single-row
    array holding those values.
    """
    deserialized_np_array = handler.default_input_fn(csv_data, content_types.CSV)
    # np.testing reports the mismatching elements on failure, whereas a bare
    # ``assert np.array_equal(...)`` only reports ``assert False``.
    np.testing.assert_array_equal(deserialized_np_array, expected)


@pytest.mark.parametrize('np_array, expected', [
    ([42, 6, 9], np.array([[42, 6, 9]], dtype=np.float32)),
    ([42., 6., 9.], np.array([[42, 6, 9]], dtype=np.float32))])
def test_input_fn_npz(np_array, expected):
    """Check that ``default_input_fn`` decodes NPY payloads into ``expected``.

    Despite the ``npz`` in the name, the payload is produced with
    ``encoder._array_to_npy`` and sent with ``content_types.NPY``.

    The same values are serialized three ways: as the plain Python list
    (``dtype`` of ``None`` below), as a float32 array and as a float64 array.
    Every variant must decode to ``expected``.
    """
    for dtype in (None, np.float32, np.float64):
        source = np_array if dtype is None else np.array(np_array, dtype=dtype)
        input_data = encoder._array_to_npy(source)
        deserialized_np_array = handler.default_input_fn(input_data, content_types.NPY)

        # err_msg identifies which of the three variants failed.
        np.testing.assert_array_equal(
            deserialized_np_array, expected,
            err_msg='NPY decoding mismatch for source dtype {}'.format(dtype))


def test_input_fn_bad_content_type():
    """``default_input_fn`` must raise ``UnsupportedFormatError`` for an unknown content type."""
    unsupported_content_type = 'application/not_supported'
    with pytest.raises(errors.UnsupportedFormatError):
        handler.default_input_fn('', unsupported_content_type)


def test_default_model_fn():
    """``default_model_fn`` is expected to raise ``NotImplementedError``."""
    model_dir = 'model_dir'
    with pytest.raises(NotImplementedError):
        handler.default_model_fn(model_dir)


def test_predict_fn(np_array):
    """Check that ``default_predict_fn`` calls the model's ``predict`` exactly once.

    Args:
        np_array: input array supplied by the ``np_array`` fixture.
    """
    estimator = FakeEstimator()
    with patch.object(estimator, 'predict') as mock_predict:
        handler.default_predict_fn(np_array, estimator)

    # Check call_count explicitly instead of calling ``assert_called_once()``:
    # that helper is missing from older ``mock`` releases, where calling it
    # either passes silently (auto-created attribute) or raises AttributeError.
    assert mock_predict.call_count == 1


def test_output_fn_json(np_array):
    """``default_output_fn`` returns the JSON-encoded array paired with the JSON content type."""
    response = handler.default_output_fn(np_array, content_types.JSON)
    expected_body = encoder._array_to_json(np_array.tolist())
    expected_response = (expected_body, content_types.JSON)
    assert response == expected_response


def test_output_fn_csv(np_array):
    """``default_output_fn`` returns the array as CSV text paired with the CSV content type."""
    response = handler.default_output_fn(np_array, content_types.CSV)
    # One comma-separated line per row of the 2x2 array of ones.
    expected_response = ('1.0,1.0\n1.0,1.0\n', content_types.CSV)
    assert response == expected_response


def test_output_fn_npz(np_array):
    """``default_output_fn`` returns the NPY-encoded array paired with the NPY content type.

    Despite the ``npz`` in the name, the accept type used here is ``content_types.NPY``.
    """
    response = handler.default_output_fn(np_array, content_types.NPY)
    expected_body = encoder._array_to_npy(np_array)
    expected_response = (expected_body, content_types.NPY)
    assert response == expected_response


def test_input_fn_bad_accept():
    """``default_output_fn`` must raise ``UnsupportedFormatError`` for an unknown accept type.

    Note that, despite the ``input_fn`` in the name, this exercises ``default_output_fn``.
    """
    unsupported_accept = 'application/not_supported'
    with pytest.raises(errors.UnsupportedFormatError):
        handler.default_output_fn('', unsupported_accept)