# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import sagemaker
import tests.integ
import tests.integ.timeout
from sagemaker.model_monitor import DataCaptureConfig, NetworkConfig
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.utils import unique_name_from_base
from tests.integ.retry import retries

ROLE = "SageMakerRole"
SKLEARN_FRAMEWORK = "scikit-learn"

INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.m5.xlarge"
VOLUME_SIZE_IN_GB = 20
MAX_RUNTIME_IN_SECONDS = 2 * 60 * 60
ENVIRONMENT = {"env_key_1": "env_value_1"}
TAGS = [{"Key": "tag_key_1", "Value": "tag_value_1"}]
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=True)

CUSTOM_SAMPLING_PERCENTAGE = 10
CUSTOM_CAPTURE_OPTIONS = ["REQUEST"]
CUSTOM_CSV_CONTENT_TYPES = ["text/csvtype1", "text/csvtype2"]
CUSTOM_JSON_CONTENT_TYPES = ["application/jsontype1", "application/jsontype2"]


def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
    sagemaker_session, tensorflow_inference_latest_version
):
    endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
        key_prefix="tensorflow-serving/models",
    )
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = TensorFlowModel(
            model_data=model_data,
            role=ROLE,
            framework_version=tensorflow_inference_latest_version,
            sagemaker_session=sagemaker_session,
        )
        predictor = model.deploy(
            initial_instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE,
            endpoint_name=endpoint_name,
        )

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc.get("DataCaptureConfig") is None

        predictor.enable_data_capture()

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 25 retries * 60s sleeps = 25min timeout
        for _ in retries(
            max_retry_count=25,
            exception_message_prefix="Waiting for 'InService' endpoint status",
            seconds_to_sleep=60,
        ):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=predictor.endpoint_name
            )
            if new_endpoint["EndpointStatus"] == "InService":
                break

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]


def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
    sagemaker_session, tensorflow_inference_latest_version
):
    endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
        key_prefix="tensorflow-serving/models",
    )
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = TensorFlowModel(
            model_data=model_data,
            role=ROLE,
            framework_version=tensorflow_inference_latest_version,
            sagemaker_session=sagemaker_session,
        )
        destination_s3_uri = os.path.join(
            "s3://", sagemaker_session.default_bucket(), endpoint_name, "custom"
        )
        predictor = model.deploy(
            initial_instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE,
            endpoint_name=endpoint_name,
            data_capture_config=DataCaptureConfig(
                enable_capture=True,
                sampling_percentage=CUSTOM_SAMPLING_PERCENTAGE,
                destination_s3_uri=destination_s3_uri,
                capture_options=CUSTOM_CAPTURE_OPTIONS,
                csv_content_types=CUSTOM_CSV_CONTENT_TYPES,
                json_content_types=CUSTOM_JSON_CONTENT_TYPES,
                sagemaker_session=sagemaker_session,
            ),
        )

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]
        assert (
            endpoint_config_desc["DataCaptureConfig"]["InitialSamplingPercentage"]
            == CUSTOM_SAMPLING_PERCENTAGE
        )
        assert endpoint_config_desc["DataCaptureConfig"]["CaptureOptions"] == [
            {"CaptureMode": "Input"}
        ]
        assert (
            endpoint_config_desc["DataCaptureConfig"]["CaptureContentTypeHeader"]["CsvContentTypes"]
            == CUSTOM_CSV_CONTENT_TYPES
        )
        assert (
            endpoint_config_desc["DataCaptureConfig"]["CaptureContentTypeHeader"][
                "JsonContentTypes"
            ]
            == CUSTOM_JSON_CONTENT_TYPES
        )

        predictor.disable_data_capture()

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 25 retries * 60s sleeps = 25min timeout
        for _ in retries(
            max_retry_count=25,
            exception_message_prefix="Waiting for 'InService' endpoint status",
            seconds_to_sleep=60,
        ):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=predictor.endpoint_name
            )
            if new_endpoint["EndpointStatus"] == "InService":
                break

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert not endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]


def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
    sagemaker_session, tensorflow_inference_latest_version
):
    endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
        key_prefix="tensorflow-serving/models",
    )
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = TensorFlowModel(
            model_data=model_data,
            role=ROLE,
            framework_version=tensorflow_inference_latest_version,
            sagemaker_session=sagemaker_session,
        )
        destination_s3_uri = os.path.join(
            "s3://", sagemaker_session.default_bucket(), endpoint_name, "custom"
        )
        predictor = model.deploy(
            initial_instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE,
            endpoint_name=endpoint_name,
        )

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc.get("DataCaptureConfig") is None

        predictor.update_data_capture_config(
            data_capture_config=DataCaptureConfig(
                enable_capture=True,
                sampling_percentage=CUSTOM_SAMPLING_PERCENTAGE,
                destination_s3_uri=destination_s3_uri,
                capture_options=CUSTOM_CAPTURE_OPTIONS,
                csv_content_types=CUSTOM_CSV_CONTENT_TYPES,
                json_content_types=CUSTOM_JSON_CONTENT_TYPES,
                sagemaker_session=sagemaker_session,
            )
        )

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 25 retries * 60s sleeps = 25min timeout
        for _ in retries(
            max_retry_count=25,
            exception_message_prefix="Waiting for 'InService' endpoint status",
            seconds_to_sleep=60,
        ):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=predictor.endpoint_name
            )
            if new_endpoint["EndpointStatus"] == "InService":
                break

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]
        assert (
            endpoint_config_desc["DataCaptureConfig"]["InitialSamplingPercentage"]
            == CUSTOM_SAMPLING_PERCENTAGE
        )
        assert endpoint_config_desc["DataCaptureConfig"]["CaptureOptions"] == [
            {"CaptureMode": "Input"}
        ]
        assert (
            endpoint_config_desc["DataCaptureConfig"]["CaptureContentTypeHeader"]["CsvContentTypes"]
            == CUSTOM_CSV_CONTENT_TYPES
        )
        assert (
            endpoint_config_desc["DataCaptureConfig"]["CaptureContentTypeHeader"][
                "JsonContentTypes"
            ]
            == CUSTOM_JSON_CONTENT_TYPES
        )