# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The SageMaker JumpStart Industry utils module."""
from __future__ import absolute_import

import re
import os
import json
from typing import Callable
import pandas as pd
from smjsindustry.finance.constants import (
    IMAGE_CONFIG_FILE,
    ECR_URI_TEMPLATE,
    REPOSITORY,
    CONTAINER_IMAGE_VERSION,
)


def _get_freq_label_by_day(date_value: str) -> str:
    """Gets frequency label for the date value which is aggregated by day.

    Args:
        date_value (str): The date value.

    Returns:
        str: The date value aggregated by day.
    """
    if not bool(re.match(r"^\d{4}-\d{1,2}-\d{1,2}$", date_value)):
        raise ValueError("Date needs to be in yyyy-mm-dd format when freq is D")
    return date_value


def _get_freq_label_by_week(date_value: str) -> str:
    """Gets frequency label for the date value which is aggregated by week.

    Args:
        date_value (str): The date value.

    Returns:
        str: The date value aggregated by week.
    """
    if bool(re.match(r"^\d{4}W\d{1,2}$", date_value)):
        return date_value
    if not bool(re.match(r"^\d{4}-\d{1,2}-\d{1,2}$", date_value)):
        raise ValueError("Date needs to be in yyyy-mm-dd format when freq is W")
    ts = pd.Timestamp(date_value)
    return "{}W{}".format(ts.year, ts.week)


def _get_freq_label_by_month(date_value: str) -> str:
    """Gets frequency label for the date value which is aggregated by month.

    Args:
        date_value (str): The date value.

    Returns:
        str: The date value aggregated by month.
    """
    if bool(re.match(r"^\d{4}M\d{1,2}$", date_value)):
        return date_value
    if not bool(re.match(r"^\d{4}-\d{1,2}(-\d{1,2})?$", date_value)):
        raise ValueError("Date needs to be in yyyy-mm-dd or yyyy-mm format when freq is M")
    ts = pd.Timestamp(date_value)
    return "{}M{}".format(ts.year, ts.month)


def _get_freq_label_by_quarter(date_value: str) -> str:
    """Gets frequency label for the date value which is aggregated by quarter.

    Args:
        date_value (str): The date value.

    Returns:
        str: The date value aggregated by quarter.
    """
    if bool(re.match(r"^\d{4}Q\d{1,2}$", date_value)):
        return date_value
    if not bool(re.match(r"^\d{4}-\d{1,2}(-\d{1,2})?$", date_value)):
        raise ValueError("Date needs to be in yyyy-mm-dd or yyyy-mm format when freq is Q")
    ts = pd.Timestamp(date_value)
    return "{}Q{}".format(ts.year, ts.quarter)


def _get_freq_label_by_year(date_value: str) -> str:
    """Gets frequency label for the date value which is aggregated by year.

    Args:
        date_value (str): The date value.

    Returns:
        str: The date value aggregated by year.
    """
    if bool(re.match(r"^\d{4}$", date_value)):
        return date_value
    if not bool(re.match(r"^\d{4}(-\d{1,2}){0,2}$", date_value)):
        raise ValueError("Date needs to be in yyyy-mm-dd, yyyy-mm or yyyy format when freq is Y")
    ts = pd.Timestamp(date_value)
    return str(ts.year)


# Dispatch table from an upper-case frequency code to the helper that builds
# the label for that aggregation level. Each helper takes the date string and
# returns the label string.
FREQ_LABEL_MAP = {
    "D": _get_freq_label_by_day,
    "W": _get_freq_label_by_week,
    "M": _get_freq_label_by_month,
    "Q": _get_freq_label_by_quarter,
    "Y": _get_freq_label_by_year,
}


def get_freq_label(date_value: str, freq: str) -> str:
    """Gets frequency label for the date value.

    Args:
        date_value (str): The date value. It is upper-cased before being
            handed to the frequency-specific helper.
        freq (str): The frequency value specifies how the date field should be aggregated,
            by year, quarter, month, week, day. Matched case-insensitively.
            Available values: ``{'Y', 'Q', 'M', 'W', 'D'}``.

    Returns:
        str: The label of the period containing the date value, as produced
            by the helper registered for ``freq`` in ``FREQ_LABEL_MAP``.

    Raises:
        ValueError: If ``freq`` is not a supported frequency, or if the helper
            rejects the format of ``date_value``.
        TypeError: If ``date_value`` is not a string.
    """
    freq = freq.upper()
    if freq not in FREQ_LABEL_MAP:
        raise ValueError("frequency {} not supported".format(freq))
    if not isinstance(date_value, str):
        # TypeError is still an Exception, so callers that caught the former
        # generic Exception keep working.
        raise TypeError("The date column needs to be string")
    return FREQ_LABEL_MAP[freq](date_value.upper())


def load_image_uri_config():
    """Loads the JSON config for the image URI.

    The config file is the one named ``IMAGE_CONFIG_FILE`` in the directory
    that contains this module.

    Returns:
        JSON object: The JSON object of the image URI config.
    """
    fname = os.path.join(os.path.dirname(__file__), IMAGE_CONFIG_FILE)
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(fname, encoding="utf-8") as f:
        return json.load(f)


def retrieve_image(region):
    """Retrieves the Amazon ECR image URI for the Docker image matching the given region.

    The account ID is looked up by region in the image URI config, and the
    repository reference is ``REPOSITORY`` tagged with
    ``CONTAINER_IMAGE_VERSION``.

    Args:
        region (str): The AWS Region.

    Returns:
        str: the Amazon ECR image URI for the corresponding Docker image.
    """
    account_by_region = load_image_uri_config()
    registry_account = account_by_region[region]
    tagged_repository = "{}:{}".format(REPOSITORY, CONTAINER_IMAGE_VERSION)
    return ECR_URI_TEMPLATE.format(
        account_id=registry_account,
        region=region,
        repository=tagged_repository,
    )