# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import

import os
import platform
import sys
import tempfile
from abc import ABCMeta, abstractmethod

from six import with_metaclass
from six.moves.urllib.parse import urlparse

import sagemaker.amazon.common
import sagemaker.local.utils
import sagemaker.utils


def get_data_source_instance(data_source, sagemaker_session):
    """Return an Instance of :class:`sagemaker.local.data.DataSource`.

    The instance can handle the provided data_source URI.

    data_source can be either file:// or s3://

    Args:
        data_source (str): a valid URI that points to a data source.
        sagemaker_session (:class:`sagemaker.session.Session`): a SageMaker Session to
            interact with S3 if required.

    Returns:
        sagemaker.local.data.DataSource: an instance of a DataSource

    Raises:
        ValueError: If the scheme of the parsed URI is neither `file` nor `s3`.
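
    Example:
        A minimal usage sketch, assuming ``/tmp/data`` exists locally and
        ``session`` is a configured :class:`sagemaker.session.Session`:

        >>> source = get_data_source_instance("file:///tmp/data", session)  # doctest: +SKIP
        >>> files = source.get_file_list()  # doctest: +SKIP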
    """
    parsed_uri = urlparse(data_source)
    if parsed_uri.scheme == "file":
        return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
    if parsed_uri.scheme == "s3":
        return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
    raise ValueError(
        "data_source must be either file or s3. parsed_uri.scheme: {}".format(parsed_uri.scheme)
    )


def get_splitter_instance(split_type):
    """Return an Instance of :class:`sagemaker.local.data.Splitter`.

    The instance returned is according to the specified `split_type`.

    Args:
        split_type (str): either 'Line' or 'RecordIO'. Can be None to signal
            that no data splitting should happen.

    Returns:
        :class:`sagemaker.local.data.Splitter`: an instance of a Splitter

    Raises:
        ValueError: If `split_type` is not one of None, 'Line', or 'RecordIO'.
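
    Example:
        A quick sketch of selecting a splitter:

        >>> splitter = get_splitter_instance("Line")
        >>> isinstance(splitter, LineSplitter)
        True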
    """
    if split_type is None:
        return NoneSplitter()
    if split_type == "Line":
        return LineSplitter()
    if split_type == "RecordIO":
        return RecordIOSplitter()
    raise ValueError("Invalid Split Type: %s" % split_type)


def get_batch_strategy_instance(strategy, splitter):
    """Return an Instance of :class:`sagemaker.local.data.BatchStrategy` according to `strategy`

    Args:
        strategy (str): Either 'SingleRecord' or 'MultiRecord'
        splitter (:class:`sagemaker.local.data.Splitter`): splitter to get the data from.

    Returns:
        :class:`sagemaker.local.data.BatchStrategy`: an instance of a BatchStrategy

    Raises:
        ValueError: If `strategy` is neither 'SingleRecord' nor 'MultiRecord'.
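
    Example:
        A minimal sketch pairing a strategy with a splitter:

        >>> strategy = get_batch_strategy_instance("MultiRecord", LineSplitter())
        >>> isinstance(strategy, MultiRecordStrategy)
        True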
    """
    if strategy == "SingleRecord":
        return SingleRecordStrategy(splitter)
    if strategy == "MultiRecord":
        return MultiRecordStrategy(splitter)
    raise ValueError(
        'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"' % strategy
    )


class DataSource(with_metaclass(ABCMeta, object)):
    """Placeholder docstring"""

    @abstractmethod
    def get_file_list(self):
        """Retrieve the list of absolute paths to all the files in this data source.

        Returns:
            List[str]: List of absolute paths.
        """

    @abstractmethod
    def get_root_dir(self):
        """Retrieve the absolute path to the root directory of this data source.

        Returns:
            str: absolute path to the root directory of this data source.
        """


class LocalFileDataSource(DataSource):
    """Represents a data source within the local filesystem."""

    def __init__(self, root_path):
        super(LocalFileDataSource, self).__init__()

        self.root_path = os.path.abspath(root_path)
        if not os.path.exists(self.root_path):
            raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)

    def get_file_list(self):
        """Retrieve the list of absolute paths to all the files in this data source.

        Returns:
            List[str]: List of absolute paths.
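
        Example:
            A sketch, assuming ``/tmp/data`` is a directory containing files:

            >>> LocalFileDataSource("/tmp/data").get_file_list()  # doctest: +SKIP
            ['/tmp/data/a.csv', '/tmp/data/b.csv']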
        """
        if os.path.isdir(self.root_path):
            return [
                os.path.join(self.root_path, f)
                for f in os.listdir(self.root_path)
                if os.path.isfile(os.path.join(self.root_path, f))
            ]
        return [self.root_path]

    def get_root_dir(self):
        """Retrieve the absolute path to the root directory of this data source.

        Returns:
            str: absolute path to the root directory of this data source.
        """
        if os.path.isdir(self.root_path):
            return self.root_path
        return os.path.dirname(self.root_path)


class S3DataSource(DataSource):
    """Defines a data source given by a bucket and S3 prefix.

    The contents will be downloaded and then processed as local data.
    """

    def __init__(self, bucket, prefix, sagemaker_session):
        """Create an S3DataSource instance.

        Args:
            bucket (str): S3 bucket name
            prefix (str): S3 prefix path to the data
            sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker_session
                with the desired settings to talk to S3
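
        Example:
            A sketch, assuming the bucket and prefix exist and ``session`` has
            credentials; contents are downloaded to a local temp directory:

            >>> source = S3DataSource("my-bucket", "train/", session)  # doctest: +SKIP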
        """
        super(S3DataSource, self).__init__()

        # Create a temporary dir to store the S3 contents
        root_dir = sagemaker.utils.get_config_value(
            "local.container_root", sagemaker_session.config
        )
        if root_dir:
            root_dir = os.path.abspath(root_dir)

        working_dir = tempfile.mkdtemp(dir=root_dir)
        # Docker cannot mount Mac OS /var folder properly see
        # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600
        # Only apply this workaround if the user didn't provide an alternate storage root dir.
        if root_dir is None and platform.system() == "Darwin":
            working_dir = "/private{}".format(working_dir)

        sagemaker.utils.download_folder(bucket, prefix, working_dir, sagemaker_session)
        self.files = LocalFileDataSource(working_dir)

    def get_file_list(self):
        """Retrieve the list of absolute paths to all the files in this data source.

        Returns:
            List[str]: List of absolute paths.
        """
        return self.files.get_file_list()

    def get_root_dir(self):
        """Retrieve the absolute path to the root directory of this data source.

        Returns:
            str: absolute path to the root directory of this data source.
        """
        return self.files.get_root_dir()


class Splitter(with_metaclass(ABCMeta, object)):
    """Placeholder docstring"""

    @abstractmethod
    def split(self, file):
        """Split a file into records using a specific strategy

        Args:
            file (str): path to the file to split

        Returns:
            generator for the individual records that were split from the file
        """


class NoneSplitter(Splitter):
    """Does not split records, essentially reads the whole file."""

    # Bytes found in text files; anything outside this set marks the buffer as binary.
    _textchars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})

    def split(self, filename):
        """Split a file into records using a specific strategy.

        For this NoneSplitter no actual split happens; the file is
        returned as a whole.

        Args:
            filename (str): path to the file to split

        Returns:
            generator for the individual records that were split from
            the file
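
        Example:
            A sketch, assuming ``/tmp/payload.json`` is a small text file;
            the whole file is yielded as a single record:

            >>> record = next(NoneSplitter().split("/tmp/payload.json"))  # doctest: +SKIP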
        """
        with open(filename, "rb") as f:
            buf = f.read()
            if not self._is_binary(buf):
                buf = buf.decode()
            yield buf

    def _is_binary(self, buf):
        """Check whether `buf` contains binary data.

        Returns True if `buf` contains any bytes outside the set of
        characters typically found in text files.

        Args:
            buf (bytes): data to inspect

        Returns:
            True if data is binary, otherwise False
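
        Example:
            Printable text yields False; raw control bytes yield True:

            >>> NoneSplitter()._is_binary(b"hello world")
            False
            >>> NoneSplitter()._is_binary(bytes([0, 1, 2]))
            True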
        """
        return bool(buf.translate(None, self._textchars))


class LineSplitter(Splitter):
    """Split records by new line."""

    def split(self, file):
        """Split a file into records using a specific strategy

        This LineSplitter splits the file on each line break.

        Args:
            file (str): path to the file to split

        Returns:
            generator for the individual records that were split from
            the file
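
        Example:
            A sketch, assuming ``/tmp/data.csv`` is a small text file; each
            yielded record keeps its trailing newline:

            >>> records = list(LineSplitter().split("/tmp/data.csv"))  # doctest: +SKIP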
        """
        with open(file, "r") as f:
            for line in f:
                yield line


class RecordIOSplitter(Splitter):
    """Split using Amazon Recordio.

    Not useful for string content.
    """

    def split(self, file):
        """Split a file into records using a specific strategy

        This RecordIOSplitter splits the data into individual RecordIO
        records.

        Args:
            file (str): path to the file to split

        Returns:
            generator for the individual records that were split from
            the file
        """
        with open(file, "rb") as f:
            for record in sagemaker.amazon.common.read_recordio(f):
                yield record


class BatchStrategy(with_metaclass(ABCMeta, object)):
    """Placeholder docstring"""

    def __init__(self, splitter):
        """Create a Batch Strategy Instance

        Args:
            splitter (sagemaker.local.data.Splitter): A Splitter to pre-process
                the data before batching.
        """
        self.splitter = splitter

    @abstractmethod
    def pad(self, file, size):
        """Group together as many records as possible to fit in the specified size.

        Args:
            file (str): file path to read the records from.
            size (int): maximum size in MB that each group of records will be
                fitted to. Passing 0 means unlimited size.

        Returns:
            generator of records
        """


class MultiRecordStrategy(BatchStrategy):
    """Feed multiple records at a time for batch inference.

    Will group as many records as possible within the specified payload size.
    """

    def pad(self, file, size=6):
        """Group together as many records as possible to fit in the specified size.

        Args:
            file (str): file path to read the records from.
            size (int): maximum size in MB that each group of records will be
                fitted to. Passing 0 means unlimited size.

        Returns:
            generator of records
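
        Example:
            A sketch, assuming ``records.csv`` exists and is line-splittable;
            each yielded payload packs as many records as fit in ``size`` MB:

            >>> strategy = MultiRecordStrategy(LineSplitter())
            >>> batches = list(strategy.pad("records.csv", size=6))  # doctest: +SKIP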
        """
        buffer = ""
        for element in self.splitter.split(file):
            if _payload_size_within_limit(buffer + element, size):
                buffer += element
            else:
                tmp = buffer
                buffer = element
                yield tmp
        if _validate_payload_size(buffer, size):
            yield buffer


class SingleRecordStrategy(BatchStrategy):
    """Feed a single record at a time for batch inference.

    If a single record does not fit within the specified payload size,
    a RuntimeError is raised.
    """

    def pad(self, file, size=6):
        """Group together as many records as possible to fit in the specified size.

        This SingleRecordStrategy will not group any record and will return
        them one by one as long as they are within the maximum size.

        Args:
            file (str): file path to read the records from.
            size (int): maximum size in MB that each group of records will be
                fitted to. Passing 0 means unlimited size.

        Returns:
            generator of records
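
        Example:
            A sketch, assuming ``records.csv`` exists; records are yielded one
            at a time, and an oversized record raises a RuntimeError:

            >>> strategy = SingleRecordStrategy(LineSplitter())
            >>> records = list(strategy.pad("records.csv", size=6))  # doctest: +SKIP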
        """
        for element in self.splitter.split(file):
            if _validate_payload_size(element, size):
                yield element


def _payload_size_within_limit(payload, size):
    """Placeholder docstring."""
    size_in_bytes = size * 1024 * 1024
    if size == 0:
        return True
    return sys.getsizeof(payload) < size_in_bytes


def _validate_payload_size(payload, size):
    """Check if a payload is within the size in MB threshold.

    Raise an exception if the payload is beyond the size in MB threshold.

    Args:
        payload: data that will be checked
        size (int): max size in MB

    Returns:
        bool: True if within bounds; if size=0 it always returns True

    Raises:
        RuntimeError: If the payload exceeds the size threshold.
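
    Example:
        >>> _validate_payload_size("small payload", 6)
        True
        >>> _validate_payload_size("anything at all", 0)
        True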
    """

    if _payload_size_within_limit(payload, size):
        return True
    raise RuntimeError("Record is larger than %sMB. Please increase your max_payload" % size)