# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import

import logging
import os
import platform
from datetime import datetime

import boto3
from botocore.exceptions import ClientError

from sagemaker.config import (
    load_sagemaker_config,
    validate_sagemaker_config,
    SESSION_DEFAULT_S3_BUCKET_PATH,
    SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
)
from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import get_docker_host
from sagemaker.local.entities import (
    _LocalEndpointConfig,
    _LocalEndpoint,
    _LocalModel,
    _LocalProcessingJob,
    _LocalTrainingJob,
    _LocalTransformJob,
    _LocalPipeline,
)
from sagemaker.session import Session
from sagemaker.utils import get_config_value, _module_import_error, resolve_value_from_config

logger = logging.getLogger(__name__)


class LocalSagemakerClient(object):  # pylint: disable=too-many-public-methods
    """Local Mode substitute for the boto SageMaker client.

    Method names and signatures follow the boto SageMakerClient, but the work is carried
    out through the ``sagemaker.local`` containers and entities instead of SageMaker API
    calls. It is used for local training and for hosting local endpoints; a boto client
    is still needed for S3 access, which is taken from the session given to the
    constructor.

    Every resource created through this client is recorded in a class-level dictionary
    keyed by its name, so all instances of this class share the same view of jobs,
    models, endpoint configurations, endpoints and pipelines.
    """

    # Name -> local entity registries. They live on the class, not on the instance.
    _processing_jobs = {}
    _training_jobs = {}
    _transform_jobs = {}
    _models = {}
    _endpoint_configs = {}
    _endpoints = {}
    _pipelines = {}

    def __init__(self, sagemaker_session=None):
        """Initialize a LocalSageMakerClient.

        Args:
            sagemaker_session (sagemaker.session.Session): session that configuration and
                the boto client are read from. A new ``LocalSession`` is created when the
                argument is omitted (or falsy).
        """
        self.sagemaker_session = sagemaker_session or LocalSession()

    @staticmethod
    def _find_or_raise(registry, name, message, operation):
        """Return ``registry[name]``, or raise a ``ValidationException`` ``ClientError``.

        Args:
            registry (dict): one of the class-level name -> entity dictionaries.
            name (str): name of the resource to look up.
            message (str): error message used when ``name`` is not registered.
            operation (str): operation name reported in the ``ClientError``.

        Raises:
            botocore.exceptions.ClientError: if ``name`` is not a key of ``registry``.
        """
        if name in registry:
            return registry[name]
        raise ClientError({"Error": {"Code": "ValidationException", "Message": message}}, operation)

    @staticmethod
    def _find_pipeline(name, operation):
        """Return the registered local pipeline called ``name``.

        Args:
            name (str): pipeline name.
            operation (str): operation name reported in the ``ClientError``.

        Raises:
            botocore.exceptions.ClientError: with code ``ResourceNotFound`` if no pipeline
                with that name has been created.
        """
        known_pipelines = LocalSagemakerClient._pipelines
        if name in known_pipelines:
            return known_pipelines[name]
        not_found = {
            "Error": {
                "Code": "ResourceNotFound",
                "Message": "Pipeline {} does not exist".format(name),
            }
        }
        raise ClientError(not_found, operation)

    def create_processing_job(
        self,
        ProcessingJobName,
        AppSpecification,
        ProcessingResources,
        Environment=None,
        ProcessingInputs=None,
        ProcessingOutputConfig=None,
        **kwargs
    ):
        """Run a processing job in Local Mode and register it under its name.

        Args:
          ProcessingJobName(str): local processing job name.
          AppSpecification(dict): identifies the container and application to run. Must
            contain ``ImageUri``; ``ContainerEntrypoint`` and ``ContainerArguments`` are
            used when present.
          ProcessingResources(dict): resources for the job; ``ClusterConfig.InstanceType``
            and ``ClusterConfig.InstanceCount`` are read from it.
          Environment(dict, optional): environment variables to pass to the container.
            (Default value = None, treated as an empty dict)
          ProcessingInputs(list, optional): processing input data.
            (Default value = None, treated as an empty list)
          ProcessingOutputConfig(dict, optional): processing output configuration.
            (Default value = None, treated as an empty dict)
          **kwargs: ``ExperimentConfig``, ``NetworkConfig`` and ``StoppingCondition`` only
            produce a warning; other keyword arguments are ignored.

        Returns:
            None
        """
        unsupported_options = (
            ("ExperimentConfig", "Experiment configuration is not supported in local mode."),
            ("NetworkConfig", "Network configuration is not supported in local mode."),
            ("StoppingCondition", "Stopping condition is not supported in local mode."),
        )
        for option_name, warning_text in unsupported_options:
            if option_name in kwargs:
                logger.warning(warning_text)

        container = _SageMakerContainer(
            ProcessingResources["ClusterConfig"]["InstanceType"],
            ProcessingResources["ClusterConfig"]["InstanceCount"],
            AppSpecification["ImageUri"],
            sagemaker_session=self.sagemaker_session,
            container_entrypoint=AppSpecification.get("ContainerEntrypoint"),
            container_arguments=AppSpecification.get("ContainerArguments"),
        )
        job = _LocalProcessingJob(container)
        logger.info("Starting processing job")
        job.start(
            ProcessingInputs or [],
            ProcessingOutputConfig or {},
            Environment or {},
            ProcessingJobName,
        )

        # The job is only registered once start() has returned.
        LocalSagemakerClient._processing_jobs[ProcessingJobName] = job

    def describe_processing_job(self, ProcessingJobName):
        """Describe a local processing job.

        Args:
          ProcessingJobName(str): name of the processing job to describe.

        Returns:
            The value returned by the registered job's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: if no such processing job is registered.
        """
        job = LocalSagemakerClient._find_or_raise(
            LocalSagemakerClient._processing_jobs,
            ProcessingJobName,
            "Could not find local processing job",
            "describe_processing_job",
        )
        return job.describe()

    def create_training_job(
        self,
        TrainingJobName,
        AlgorithmSpecification,
        OutputDataConfig,
        ResourceConfig,
        InputDataConfig=None,
        Environment=None,
        **kwargs
    ):
        """Run a training job in Local Mode and register it under its name.

        Args:
          TrainingJobName(str): local training job name.
          AlgorithmSpecification(dict): identifies the training algorithm; its
            ``TrainingImage`` entry is used as the container image.
          OutputDataConfig(dict): location where the results of model training are saved.
          ResourceConfig(dict): resources for training; ``InstanceType`` and
            ``InstanceCount`` are read from it.
          InputDataConfig(dict, optional): training dataset description and location.
            (Default value = None, treated as an empty dict)
          Environment(dict, optional): environment variables to pass to the container.
            (Default value = None, treated as an empty dict)
          **kwargs: only ``HyperParameters`` (dict, optional) is read; it defaults to an
            empty dict.

        Returns:
            None
        """
        hyperparameters = kwargs.get("HyperParameters", {})
        container = _SageMakerContainer(
            ResourceConfig["InstanceType"],
            ResourceConfig["InstanceCount"],
            AlgorithmSpecification["TrainingImage"],
            sagemaker_session=self.sagemaker_session,
        )
        job = _LocalTrainingJob(container)
        logger.info("Starting training job")
        job.start(
            InputDataConfig or {},
            OutputDataConfig,
            hyperparameters,
            Environment or {},
            TrainingJobName,
        )

        # The job is only registered once start() has returned.
        LocalSagemakerClient._training_jobs[TrainingJobName] = job

    def describe_training_job(self, TrainingJobName):
        """Describe a local training job.

        Args:
          TrainingJobName(str): name of the training job to describe.

        Returns:
            The value returned by the registered job's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: if no such training job is registered.
        """
        job = LocalSagemakerClient._find_or_raise(
            LocalSagemakerClient._training_jobs,
            TrainingJobName,
            "Could not find local training job",
            "describe_training_job",
        )
        return job.describe()

    def create_transform_job(
        self,
        TransformJobName,
        ModelName,
        TransformInput,
        TransformOutput,
        TransformResources,
        **kwargs
    ):
        """Create, register and start a local transform job.

        Args:
          TransformJobName: name the job is registered under.
          ModelName: name of the model, passed to the local transform job.
          TransformInput: forwarded to the job's ``start``.
          TransformOutput: forwarded to the job's ``start``.
          TransformResources: forwarded to the job's ``start``.
          **kwargs: forwarded to the job's ``start``.

        Returns:
            None
        """
        job = _LocalTransformJob(TransformJobName, ModelName, self.sagemaker_session)
        # Unlike training/processing jobs, the transform job is registered before it runs.
        LocalSagemakerClient._transform_jobs[TransformJobName] = job
        job.start(TransformInput, TransformOutput, TransformResources, **kwargs)

    def describe_transform_job(self, TransformJobName):
        """Describe a local transform job.

        Args:
          TransformJobName: name of the transform job to describe.

        Returns:
            The value returned by the registered job's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: if no such transform job is registered.
        """
        job = LocalSagemakerClient._find_or_raise(
            LocalSagemakerClient._transform_jobs,
            TransformJobName,
            "Could not find local transform job",
            "describe_transform_job",
        )
        return job.describe()

    def create_model(
        self, ModelName, PrimaryContainer, *args, **kwargs
    ):  # pylint: disable=unused-argument
        """Register a local model object.

        Args:
          ModelName (str): the model name.
          PrimaryContainer (dict): a SageMaker primary container definition.
          *args: ignored.
          **kwargs: ignored.

        Returns:
            None
        """
        local_model = _LocalModel(ModelName, PrimaryContainer)
        LocalSagemakerClient._models[ModelName] = local_model

    def describe_model(self, ModelName):
        """Describe a local model.

        Args:
          ModelName: name of the model to describe.

        Returns:
            The value returned by the registered model's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: if no such model is registered.
        """
        local_model = LocalSagemakerClient._find_or_raise(
            LocalSagemakerClient._models,
            ModelName,
            "Could not find local model",
            "describe_model",
        )
        return local_model.describe()

    def describe_endpoint_config(self, EndpointConfigName):
        """Describe a local endpoint configuration.

        Args:
          EndpointConfigName: name of the endpoint configuration to describe.

        Returns:
            The value returned by the registered configuration's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: if no such endpoint configuration is
                registered.
        """
        endpoint_config = LocalSagemakerClient._find_or_raise(
            LocalSagemakerClient._endpoint_configs,
            EndpointConfigName,
            "Could not find local endpoint config",
            "describe_endpoint_config",
        )
        return endpoint_config.describe()

    def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
        """Register a local endpoint configuration.

        Args:
          EndpointConfigName: name the configuration is registered under.
          ProductionVariants: passed to the local endpoint configuration.
          Tags: passed to the local endpoint configuration. (Default value = None)

        Returns:
            None
        """
        endpoint_config = _LocalEndpointConfig(EndpointConfigName, ProductionVariants, Tags)
        LocalSagemakerClient._endpoint_configs[EndpointConfigName] = endpoint_config

    def describe_endpoint(self, EndpointName):
        """Describe a local endpoint.

        Args:
          EndpointName: name of the endpoint to describe.

        Returns:
            The value returned by the registered endpoint's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: if no such endpoint is registered.
        """
        endpoint = LocalSagemakerClient._find_or_raise(
            LocalSagemakerClient._endpoints,
            EndpointName,
            "Could not find local endpoint",
            "describe_endpoint",
        )
        return endpoint.describe()

    def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
        """Create, register and serve a local endpoint.

        Args:
          EndpointName: name the endpoint is registered under.
          EndpointConfigName: name of the endpoint configuration to use.
          Tags: passed to the local endpoint. (Default value = None)

        Returns:
            None
        """
        local_endpoint = _LocalEndpoint(
            EndpointName, EndpointConfigName, Tags, self.sagemaker_session
        )
        # Registered before serve() is called, so it stays registered if serving fails.
        LocalSagemakerClient._endpoints[EndpointName] = local_endpoint
        local_endpoint.serve()

    def update_endpoint(self, EndpointName, EndpointConfigName):  # pylint: disable=unused-argument
        """Updating an endpoint is not available in Local Mode.

        Args:
          EndpointName: unused.
          EndpointConfigName: unused.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Update endpoint name is not supported in local session.")

    def delete_endpoint(self, EndpointName):
        """Stop a local endpoint; unknown names are ignored.

        The endpoint entry itself is left in the registry.

        Args:
          EndpointName: name of the endpoint to stop.

        Returns:
            None
        """
        registered = LocalSagemakerClient._endpoints
        if EndpointName not in registered:
            return
        registered[EndpointName].stop()

    def delete_endpoint_config(self, EndpointConfigName):
        """Forget a local endpoint configuration; unknown names are ignored.

        Args:
          EndpointConfigName: name of the endpoint configuration to remove.

        Returns:
            None
        """
        LocalSagemakerClient._endpoint_configs.pop(EndpointConfigName, None)

    def delete_model(self, ModelName):
        """Forget a local model; unknown names are ignored.

        Args:
          ModelName: name of the model to remove.

        Returns:
            None
        """
        LocalSagemakerClient._models.pop(ModelName, None)

    def create_pipeline(
        self, pipeline, pipeline_description, **kwargs  # pylint: disable=unused-argument
    ):
        """Register a local pipeline under ``pipeline.name``.

        Args:
            pipeline (Pipeline): Pipeline object
            pipeline_description (str): Description of the pipeline

        Returns:
            dict: ``{"PipelineArn": pipeline.name}``
        """
        LocalSagemakerClient._pipelines[pipeline.name] = _LocalPipeline(
            pipeline=pipeline,
            pipeline_description=pipeline_description,
            local_session=self.sagemaker_session,
        )
        return {"PipelineArn": pipeline.name}

    def update_pipeline(
        self, pipeline, pipeline_description, **kwargs  # pylint: disable=unused-argument
    ):
        """Replace the definition and description of an existing local pipeline.

        The registered entry's ``last_modified_time`` is set to the current timestamp.

        Args:
            pipeline (Pipeline): Pipeline object
            pipeline_description (str): Description of the pipeline

        Returns:
            dict: ``{"PipelineArn": pipeline.name}``

        Raises:
            botocore.exceptions.ClientError: with code ``ResourceNotFound`` if the
                pipeline has not been created.
        """
        entry = LocalSagemakerClient._find_pipeline(pipeline.name, "update_pipeline")
        entry.pipeline_description = pipeline_description
        entry.pipeline = pipeline
        entry.last_modified_time = datetime.now().timestamp()
        return {"PipelineArn": pipeline.name}

    def describe_pipeline(self, PipelineName):
        """Describe a local pipeline.

        Args:
          PipelineName (str): name of the pipeline to describe.

        Returns:
            The value returned by the registered pipeline's ``describe()``.

        Raises:
            botocore.exceptions.ClientError: with code ``ResourceNotFound`` if the
                pipeline has not been created.
        """
        return LocalSagemakerClient._find_pipeline(PipelineName, "describe_pipeline").describe()

    def delete_pipeline(self, PipelineName):
        """Forget a local pipeline; unknown names are ignored.

        Args:
          PipelineName (str): name of the pipeline to remove.

        Returns:
            dict: ``{"PipelineArn": PipelineName}``, whether or not the pipeline existed.
        """
        LocalSagemakerClient._pipelines.pop(PipelineName, None)
        return {"PipelineArn": PipelineName}

    def start_pipeline_execution(self, PipelineName, **kwargs):
        """Start an execution of a local pipeline.

        Args:
          PipelineName (str): name of the pipeline to start.
          **kwargs: forwarded to the pipeline's ``start``. ``ParallelismConfiguration``
            only produces a warning; ``SelectiveExecutionConfig`` is rejected.

        Returns:
            The value returned by the registered pipeline's ``start()``.

        Raises:
            ValueError: if ``SelectiveExecutionConfig`` is passed.
            botocore.exceptions.ClientError: with code ``ResourceNotFound`` if the
                pipeline has not been created.
        """
        if "ParallelismConfiguration" in kwargs:
            logger.warning("Parallelism configuration is not supported in local mode.")
        if "SelectiveExecutionConfig" in kwargs:
            raise ValueError("SelectiveExecutionConfig is not supported in local mode.")
        local_pipeline = LocalSagemakerClient._find_pipeline(
            PipelineName, "start_pipeline_execution"
        )
        return local_pipeline.start(**kwargs)


class LocalSagemakerRuntimeClient(object):
    """SageMaker Runtime client stand-in that only talks to a locally served endpoint."""

    def __init__(self, config=None):
        """Initializes a LocalSageMakerRuntimeClient.

        Args:
            config (dict): Optional configuration for this client. Only
                ``local.serving_port`` is read from it; port 8080 is used when that
                value is missing or falsy.

        Raises:
            ImportError: if ``urllib3`` cannot be imported (an error is logged first).
        """
        try:
            import urllib3
        except ImportError:
            logger.error(_module_import_error("urllib3", "Local mode", "local"))
            raise

        self.config = config
        self.http = urllib3.PoolManager()
        configured_port = get_config_value("local.serving_port", config)
        self.serving_port = configured_port or 8080

    def invoke_endpoint(
        self,
        Body,
        EndpointName,  # pylint: disable=unused-argument
        ContentType=None,
        Accept=None,
        CustomAttributes=None,
        TargetModel=None,
        TargetVariant=None,
        InferenceId=None,
    ):
        """POST an inference request to the local serving container.

        The request goes to ``/invocations`` on the docker host at ``self.serving_port``;
        ``EndpointName`` is accepted for signature compatibility but not used.

        Args:
            Body: Input data for which you want the model to provide inference. A ``str``
                is encoded as UTF-8 before sending.
            EndpointName: The name of the endpoint that you specified when you
                created the endpoint using the CreateEndpoint API.
            ContentType: The MIME type of the input data in the request body (Default value = None)
            Accept: The desired MIME type of the inference in the response (Default value = None)
            CustomAttributes: Provides additional information about a request for an inference
                submitted to a model hosted at an Amazon SageMaker endpoint (Default value = None)
            TargetModel: The model to request for inference when invoking a multi-model endpoint
                (Default value = None)
            TargetVariant: Specify the production variant to send the inference request to when
                invoking an endpoint that is running two or more variants (Default value = None)
            InferenceId: If you provide a value, it is added to the captured data when you enable
               data capture on the endpoint (Default value = None)

        Returns:
            dict: ``"Body"`` holds the HTTP response object (requested with
            ``preload_content=False``) and ``"ContentType"`` holds the ``Accept`` argument.
        """
        url = "http://%s:%d/invocations" % (get_docker_host(), self.serving_port)

        # Header name -> argument value, in the order the headers are added. Arguments
        # left as None produce no header at all.
        optional_headers = (
            ("Content-type", ContentType),
            ("Accept", Accept),
            ("X-Amzn-SageMaker-Custom-Attributes", CustomAttributes),
            ("X-Amzn-SageMaker-Target-Model", TargetModel),
            ("X-Amzn-SageMaker-Target-Variant", TargetVariant),
            ("X-Amzn-SageMaker-Inference-Id", InferenceId),
        )
        headers = {name: value for name, value in optional_headers if value is not None}

        # The http client would encode a str body as latin-1; send UTF-8 bytes instead.
        payload = Body.encode("utf-8") if isinstance(Body, str) else Body
        response = self.http.request(
            "POST", url, body=payload, preload_content=False, headers=headers
        )

        return {"Body": response, "ContentType": Accept}


class LocalSession(Session):
    """A SageMaker ``Session`` class for Local Mode.

    This class provides alternative Local Mode implementations for the functionality of
    :class:`~sagemaker.session.Session`.
    """

    def __init__(
        self,
        boto_session=None,
        default_bucket=None,
        s3_endpoint_url=None,
        disable_local_code=False,
        sagemaker_config: dict = None,
        default_bucket_prefix=None,
    ):
        """Create a Local SageMaker Session.

        Args:
            boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
                calls are delegated to (default: None). If not provided, one is created with
                default AWS configuration chain.
            default_bucket (str): Passed through unchanged to the
                :class:`~sagemaker.session.Session` constructor as its ``default_bucket``
                argument (default: None).
            s3_endpoint_url (str): Override the default endpoint URL for Amazon S3, if set
                (default: None).
            disable_local_code (bool): Set ``True`` to override the default AWS configuration
                chain to disable the ``local.local_code`` setting, which may not be supported for
                some SDK features (default: False).
            sagemaker_config: A dictionary containing default values for the
                SageMaker Python SDK. (default: None). The dictionary must adhere to the schema
                defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`.
                If sagemaker_config is not provided and configuration files exist (at the default
                paths for admins and users, or paths set through the environment variables
                SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE),
                a new dictionary will be generated from those configuration files. Alternatively,
                this dictionary can be generated by calling
                :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the
                Session.
            default_bucket_prefix (str): The default prefix to use for S3 Object Keys. When
                objects are saved to the Session's default_bucket, the Object Key used will
                start with the default_bucket_prefix. If not provided here or within
                sagemaker_config, no additional prefix will be added.
        """
        self.s3_endpoint_url = s3_endpoint_url
        # We use this local variable to avoid disrupting the __init__->_initialize API of the
        # parent class... But overwriting it after constructor won't do anything, so prefix _ to
        # discourage external use:
        self._disable_local_code = disable_local_code

        super(LocalSession, self).__init__(
            boto_session=boto_session,
            default_bucket=default_bucket,
            sagemaker_config=sagemaker_config,
            default_bucket_prefix=default_bucket_prefix,
        )

        if platform.system() == "Windows":
            logger.warning("Windows Support for Local Mode is Experimental")

    def _initialize(
        self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs
    ):  # pylint: disable=unused-argument
        """Initialize this Local SageMaker Session.

        Sets up the boto session, the local SageMaker and runtime clients, the optional
        custom-endpoint S3 resource/client, the SDK ``sagemaker_config`` and, when
        ``~/.sagemaker/config.yaml`` exists, the Local Mode ``config``.

        Args:
          boto_session: boto3 session to use; a default ``boto3.Session()`` is created
            when None.
          sagemaker_client: unused; a ``LocalSagemakerClient`` is always created.
          sagemaker_runtime_client: unused; a ``LocalSagemakerRuntimeClient`` is always
            created.
          kwargs: only ``sagemaker_config`` is read.

        Raises:
            ValueError: if the boto session has no region configured.
            ImportError: if the Local Mode config file exists but ``yaml`` cannot be
                imported.
        """

        if boto_session is None:
            self.boto_session = boto3.Session()
        else:
            self.boto_session = boto_session

        self._region_name = self.boto_session.region_name

        if self._region_name is None:
            raise ValueError(
                "Must setup local AWS configuration with a region supported by SageMaker."
            )

        self.sagemaker_client = LocalSagemakerClient(self)
        # NOTE(review): the runtime client is built from self.config as it stands here,
        # i.e. before ~/.sagemaker/config.yaml is read further down, so a
        # ``local.serving_port`` coming from that file is not seen by this client.
        # Confirm whether that is intended.
        self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
        self.local_mode = True

        # Validate a caller-supplied config up front so a bad config fails before any
        # S3 resource is created.
        sagemaker_config = kwargs.get("sagemaker_config", None)
        if sagemaker_config:
            validate_sagemaker_config(sagemaker_config)

        if self.s3_endpoint_url is not None:
            # Use self.boto_session rather than the ``boto_session`` argument: the
            # argument is None whenever the session was created above.
            self.s3_resource = self.boto_session.resource(
                "s3", endpoint_url=self.s3_endpoint_url
            )
            self.s3_client = self.boto_session.client("s3", endpoint_url=self.s3_endpoint_url)

        if sagemaker_config:
            self.sagemaker_config = sagemaker_config
        else:
            # self.s3_resource might be None. If it is None, load_sagemaker_config will
            # create a default S3 resource, but only if it needs to fetch from S3
            self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)

        # after sagemaker_config initialization, update self._default_bucket_name_override if needed
        self._default_bucket_name_override = resolve_value_from_config(
            direct_input=self._default_bucket_name_override,
            config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
            sagemaker_session=self,
        )
        # after sagemaker_config initialization, update self.default_bucket_prefix if needed
        self.default_bucket_prefix = resolve_value_from_config(
            direct_input=self.default_bucket_prefix,
            config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
            sagemaker_session=self,
        )

        local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
        if os.path.exists(local_mode_config_file):
            try:
                import yaml
            except ImportError:
                logger.error(_module_import_error("yaml", "Local mode", "local"))
                raise

            # Close the file deterministically instead of relying on garbage collection.
            with open(local_mode_config_file, "r") as config_file:
                self.config = yaml.safe_load(config_file)

            # safe_load returns None for an empty file and may return a non-mapping, so
            # only touch the ``local`` section when it really is a dict.
            if self._disable_local_code and isinstance(self.config, dict):
                local_section = self.config.get("local")
                if isinstance(local_section, dict):
                    local_section["local_code"] = False

    def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
        """A no-op method meant to override the sagemaker client.

        Args:
          job_name: unused.
          wait: unused. (Default value = False)
          poll: unused. (Default value = 5)
          log_type: unused. (Default value = "All")

        Returns:
            None
        """
        # override logs_for_job() as it doesn't need to perform any action
        # on local mode.
        pass  # pylint: disable=unnecessary-pass

    def logs_for_processing_job(self, job_name, wait=False, poll=10):
        """A no-op method meant to override the sagemaker client.

        Args:
          job_name: unused.
          wait: unused. (Default value = False)
          poll: unused. (Default value = 10)

        Returns:
            None
        """
        # override logs_for_processing_job() as it doesn't need to perform any action
        # on local mode.
        pass  # pylint: disable=unnecessary-pass


class file_input(object):
    """Channel configuration for a FILE data source, for use in SageMaker local mode."""

    def __init__(self, fileUri, content_type=None):
        """Build the channel definition for input data of a local-mode training job.

        Args:
            fileUri: stored as ``FileUri`` of the ``FileDataSource`` entry.
            content_type: stored under ``ContentType`` when not None; the key is
                omitted otherwise. (Default value = None)
        """
        file_data_source = {
            "FileDataDistributionType": "FullyReplicated",
            "FileUri": fileUri,
        }
        self.config = {"DataSource": {"FileDataSource": file_data_source}}

        if content_type is None:
            return
        self.config["ContentType"] = content_type