# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import re

import pytest
from mock import Mock
from botocore.exceptions import ClientError
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.async_inference import AsyncInferenceResponse
from sagemaker.exceptions import (
    AsyncInferenceModelError,
    ObjectNotExistedError,
    UnexpectedClientError,
)

# S3 URI passed as ``output_path`` to every AsyncInferenceResponse built in this module.
DEFAULT_OUTPUT_PATH = "s3://some-output-path/object-name"
# S3 URI passed as ``failure_path`` in the tests that supply one.
DEFAULT_FAILURE_PATH = "s3://some-failure-path/object-name"
# Endpoint name handed to the Predictor wrapped by each AsyncPredictor.
ENDPOINT_NAME = "some-endpoint-name"
# Canned payload returned by the mocked body ``read`` and the mocked ``deserialize``.
RETURN_VALUE = 0


def empty_s3_client():
    """
    Builds a mocked S3 client whose `get_object` never returns an object.

    Every call to `get_object` raises the next error from a fixed queue,
    regardless of the arguments it is called with. In call order:

    1. `ClientError` with code "NoSuchKey"
    2. `AsyncInferenceModelError`
    3. `ObjectNotExistedError` (for `DEFAULT_OUTPUT_PATH`)
    4. `ClientError` with code "SomeOtherError"
    5. `UnexpectedClientError`

    Returns:
        Mock: the mocked S3 client.
    """
    operation = "async-inference-response-test"
    message = "some error message"

    # Order matters: Mock pops one entry per get_object call.
    queued_errors = [
        ClientError(
            error_response={"Error": {"Code": "NoSuchKey"}},
            operation_name=operation,
        ),
        AsyncInferenceModelError(message),
        ObjectNotExistedError(message, DEFAULT_OUTPUT_PATH),
        ClientError(
            error_response={"Error": {"Code": "SomeOtherError", "Message": message}},
            operation_name=operation,
        ),
        UnexpectedClientError(message),
    ]

    client = Mock(name="s3-client")
    client.get_object = Mock(name="get_object", side_effect=queued_errors)
    return client


def empty_s3_client_to_verify_exceptions_for_null_failure_path():
    """
    Builds a mocked S3 client for the tests that run without a failure path.

    Every call to `get_object` raises the next error from a fixed queue,
    regardless of the arguments it is called with. In call order:

    1. `ObjectNotExistedError` (for `DEFAULT_OUTPUT_PATH`)
    2. `UnexpectedClientError`

    Returns:
        Mock: the mocked S3 client.
    """
    # Order matters: Mock pops one entry per get_object call.
    queued_errors = [
        ObjectNotExistedError("Inference could still be running", DEFAULT_OUTPUT_PATH),
        UnexpectedClientError("some error message"),
    ]

    client = Mock(name="s3-client")
    client.get_object = Mock(name="get_object", side_effect=queued_errors)
    return client


def mock_s3_client():
    """
    Returns a mocked S3 client whose `get_object` succeeds exactly once.

    The single queued response is a dictionary with a "Body" key pointing to
    a mocked response body. The body's `read` returns `RETURN_VALUE` and its
    `close` returns None.

    Returns:
        Mock: the mocked S3 client.
    """
    s3_client = Mock(name="s3-client")

    # The label must be passed as ``name=``: the first positional parameter of
    # Mock is ``spec``, so ``Mock("body")`` would spec the mock against a str
    # instance instead of naming it.
    response_body = Mock(name="body")
    response_body.read = Mock(name="read", return_value=RETURN_VALUE)
    response_body.close = Mock(name="close", return_value=None)

    s3_client.get_object = Mock(
        name="get_object",
        side_effect=[
            {"Body": response_body},
        ],
    )
    return s3_client


def empty_deserializer():
    """
    Returns a mocked deserializer whose `deserialize` always yields `RETURN_VALUE`.
    """
    # Extra keyword arguments to Mock are set as attributes on the new mock.
    return Mock(
        name="deserializer",
        deserialize=Mock(name="deserialize", return_value=RETURN_VALUE),
    )


def test_init_():
    """
    Verifies that the constructor stores the output and failure paths it is given.
    """
    response = AsyncInferenceResponse(
        predictor_async=AsyncPredictor(Predictor(ENDPOINT_NAME)),
        output_path=DEFAULT_OUTPUT_PATH,
        failure_path=DEFAULT_FAILURE_PATH,
    )

    assert response.output_path == DEFAULT_OUTPUT_PATH
    assert response.failure_path == DEFAULT_FAILURE_PATH


def test_wrong_waiter_config_object():
    """
    Verifies that get_result raises a ValueError when `waiter_config` is a
    plain dict.
    """
    response = AsyncInferenceResponse(
        predictor_async=AsyncPredictor(Predictor(ENDPOINT_NAME)),
        output_path=DEFAULT_OUTPUT_PATH,
        failure_path=DEFAULT_FAILURE_PATH,
    )
    expected_message = "waiter_config should be a WaiterConfig object"

    with pytest.raises(ValueError, match=expected_message):
        response.get_result(waiter_config={})


def test_get_result_success():
    """
    Verifies that get_result returns the payload read from S3, and stores it
    on the response object, when no errors occur.
    """
    # Wire the async predictor to an S3 client that serves one successful object.
    async_predictor = AsyncPredictor(Predictor(ENDPOINT_NAME))
    async_predictor.s3_client = mock_s3_client()

    response = AsyncInferenceResponse(
        predictor_async=async_predictor,
        output_path=DEFAULT_OUTPUT_PATH,
        failure_path=DEFAULT_FAILURE_PATH,
    )

    fetched = response.get_result()

    assert response._result == fetched
    assert fetched == RETURN_VALUE


def test_get_result_verify_exceptions():
    """
    Verifies that get_result method raises the expected exception
    when an error occurs while fetching the result.

    The mocked S3 client from `empty_s3_client` raises one queued error per
    `get_object` call, so the three `get_result` calls below must run in this
    order.
    """
    # Initialize AsyncInferenceResponse
    predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
    predictor_async.s3_client = empty_s3_client()
    async_inference_response = AsyncInferenceResponse(
        output_path=DEFAULT_OUTPUT_PATH,
        predictor_async=predictor_async,
        failure_path=DEFAULT_FAILURE_PATH,
    )

    # Test AsyncInferenceModelError
    # NOTE(review): presumably the queued "NoSuchKey" error for the output path
    # makes get_result fetch the failure path, which consumes the queued
    # AsyncInferenceModelError -- confirm against AsyncInferenceResponse.
    with pytest.raises(AsyncInferenceModelError, match="Model returned error: some error message"):
        async_inference_response.get_result()

    # Test ObjectNotExistedError
    # `match` is applied as a regular expression via re.search; escape the
    # message so the S3 path and the "." are matched literally.
    with pytest.raises(
        ObjectNotExistedError,
        match=re.escape(f"Object not exist at {DEFAULT_OUTPUT_PATH}. some error message"),
    ):
        async_inference_response.get_result()

    # Test UnexpectedClientError
    with pytest.raises(
        UnexpectedClientError, match="Encountered unexpected client error: some error message"
    ):
        async_inference_response.get_result()


def test_get_result_with_null_failure_path():
    """
    Verifies that get_result returns the payload read from S3, and stores it
    on the response object, when the response is created with
    `failure_path=None` and no errors occur.
    """
    # Wire the async predictor to an S3 client that serves one successful object.
    async_predictor = AsyncPredictor(Predictor(ENDPOINT_NAME))
    async_predictor.s3_client = mock_s3_client()

    response = AsyncInferenceResponse(
        predictor_async=async_predictor,
        output_path=DEFAULT_OUTPUT_PATH,
        failure_path=None,
    )

    fetched = response.get_result()

    assert response._result == fetched
    assert fetched == RETURN_VALUE


def test_get_result_verify_exceptions_with_null_failure_path():
    """
    Verifies that get_result method raises the expected exception
    when an error occurs while fetching the result and the response
    was created with `failure_path=None`.

    The mocked S3 client raises one queued error per `get_object` call, so
    the two `get_result` calls below must run in this order.
    """
    # Initialize AsyncInferenceResponse
    predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
    predictor_async.s3_client = empty_s3_client_to_verify_exceptions_for_null_failure_path()
    async_inference_response = AsyncInferenceResponse(
        output_path=DEFAULT_OUTPUT_PATH,
        predictor_async=predictor_async,
        failure_path=None,
    )

    # Test ObjectNotExistedError
    # `match` is applied as a regular expression via re.search; escape the
    # message so the S3 path and the "." are matched literally.
    with pytest.raises(
        ObjectNotExistedError,
        match=re.escape(
            f"Object not exist at {DEFAULT_OUTPUT_PATH}. Inference could still be running"
        ),
    ):
        async_inference_response.get_result()

    # Test UnexpectedClientError
    with pytest.raises(
        UnexpectedClientError, match="Encountered unexpected client error: some error message"
    ):
        async_inference_response.get_result()