# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import tempfile
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Dict, Iterable

from braket.aws.aws_session import AwsSession
from braket.jobs.local.local_job_container import _LocalJobContainer


def setup_container(
    container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs
) -> Dict[str, str]:
    """Sets up a container with prerequisites for running a Braket Job. The prerequisites are
    based on the options the customer has chosen for the job. Similarly, any environment variables
    that are needed during runtime will be returned by this function.

    Args:
        container(_LocalJobContainer): The container that will run the braket job.
        aws_session (AwsSession): AwsSession for connecting to AWS Services.

    Returns:
        Dict[str, str]: A dictionary of environment variables that reflect Braket Jobs options
        requested by the customer.
    """
    logger = getLogger(__name__)
    _create_expected_paths(container, **creation_kwargs)
    run_environment_variables = {}
    run_environment_variables.update(_get_env_credentials(aws_session, logger))
    run_environment_variables.update(
        _get_env_script_mode_config(creation_kwargs["algorithmSpecification"]["scriptModeConfig"])
    )
    run_environment_variables.update(_get_env_default_vars(aws_session, **creation_kwargs))
    if _copy_hyperparameters(container, **creation_kwargs):
        run_environment_variables.update(_get_env_hyperparameters())
    if _copy_input_data_list(container, aws_session, **creation_kwargs):
        run_environment_variables.update(_get_env_input_data())
    return run_environment_variables


def _create_expected_paths(container: _LocalJobContainer, **creation_kwargs) -> None:
    """Creates the basic paths required for Braket Jobs to run.

    Args:
        container(_LocalJobContainer): The container that will run the braket job.
    """
    container.makedir("/opt/ml/model")
    container.makedir(creation_kwargs["checkpointConfig"]["localPath"])


def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> Dict[str, str]:
    """Gets the account credentials from boto so they can be added as environment variables to
    the running container.

    Args:
        aws_session (AwsSession): AwsSession for connecting to AWS Services.
        logger (Logger): Logger object with which to write logs. Default is `getLogger(__name__)`

    Returns:
        Dict[str, str]: The set of key/value pairs that should be added as environment variables
        to the running container.
    """
    credentials = aws_session.boto_session.get_credentials()
    if credentials.token is None:
        logger.info("Using the long-lived AWS credentials found in session")
        return {
            "AWS_ACCESS_KEY_ID": str(credentials.access_key),
            "AWS_SECRET_ACCESS_KEY": str(credentials.secret_key),
        }
    logger.warning(
        "Using the short-lived AWS credentials found in session. They might expire while running."
    )
    return {
        "AWS_ACCESS_KEY_ID": str(credentials.access_key),
        "AWS_SECRET_ACCESS_KEY": str(credentials.secret_key),
        "AWS_SESSION_TOKEN": str(credentials.token),
    }


def _get_env_script_mode_config(script_mode_config: Dict[str, str]) -> Dict[str, str]:
    """Gets the environment variables related to the customer script mode config.

    Args:
        script_mode_config (Dict[str, str]): The values for scriptModeConfig in the boto3 input
            parameters for running a Braket Job.

    Returns:
        Dict[str, str]: The set of key/value pairs that should be added as environment variables
        to the running container.
    """
    result = {
        "AMZN_BRAKET_SCRIPT_S3_URI": script_mode_config["s3Uri"],
        "AMZN_BRAKET_SCRIPT_ENTRY_POINT": script_mode_config["entryPoint"],
    }
    if "compressionType" in script_mode_config:
        result["AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE"] = script_mode_config["compressionType"]
    return result


def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs) -> Dict[str, str]:
    """This function gets the remaining 'simple' env variables, that don't require any
     additional logic to determine what they are or when they should be added as env variables.

    Args:
        aws_session (AwsSession): AwsSession for connecting to AWS Services.

    Returns:
        Dict[str, str]: The set of key/value pairs that should be added as environment variables
        to the running container.
    """
    job_name = creation_kwargs["jobName"]
    bucket, location = AwsSession.parse_s3_uri(creation_kwargs["outputDataConfig"]["s3Path"])
    return {
        "AWS_DEFAULT_REGION": aws_session.region,
        "AMZN_BRAKET_JOB_NAME": job_name,
        "AMZN_BRAKET_DEVICE_ARN": creation_kwargs["deviceConfig"]["device"],
        "AMZN_BRAKET_JOB_RESULTS_DIR": "/opt/braket/model",
        "AMZN_BRAKET_CHECKPOINT_DIR": creation_kwargs["checkpointConfig"]["localPath"],
        "AMZN_BRAKET_OUT_S3_BUCKET": bucket,
        "AMZN_BRAKET_TASK_RESULTS_S3_URI": f"s3://{bucket}/jobs/{job_name}/tasks",
        "AMZN_BRAKET_JOB_RESULTS_S3_PATH": str(Path(location, job_name, "output").as_posix()),
    }


def _get_env_hyperparameters() -> Dict[str, str]:
    """Gets the env variable for hyperparameters. This should only be added if the customer has
    provided hyperpameters to the job.

    Returns:
        Dict[str, str]: The set of key/value pairs that should be added as environment variables
        to the running container.
    """
    return {
        "AMZN_BRAKET_HP_FILE": "/opt/braket/input/config/hyperparameters.json",
    }


def _get_env_input_data() -> Dict[str, str]:
    """Gets the env variable for input data. This should only be added if the customer has
    provided input data to the job.

    Returns:
        Dict[str, str]: The set of key/value pairs that should be added as environment variables
        to the running container.
    """
    return {
        "AMZN_BRAKET_INPUT_DIR": "/opt/braket/input/data",
    }


def _copy_hyperparameters(container: _LocalJobContainer, **creation_kwargs) -> bool:
    """If hyperpameters are present, this function will store them as a JSON object in the
     container in the appropriate location on disk.

    Args:
        container(_LocalJobContainer): The container to save hyperparameters to.

    Returns:
        bool: True if any hyperparameters were copied to the container.
    """
    if "hyperParameters" not in creation_kwargs:
        return False
    hyperparameters = creation_kwargs["hyperParameters"]
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = Path(temp_dir, "hyperparameters.json")
        with open(file_path, "w") as write_file:
            json.dump(hyperparameters, write_file)
        container.copy_to(str(file_path), "/opt/ml/input/config/hyperparameters.json")
    return True


def _download_input_data(
    aws_session: AwsSession,
    download_dir: str,
    input_data: Dict[str, Any],
) -> None:
    """Downloads input data for a job.

    Args:
        aws_session (AwsSession): AwsSession for connecting to AWS Services.
        download_dir (str): The directory path to download to.
        input_data (Dict[str, Any]): One of the input data in the boto3 input parameters for
            running a Braket Job.
    """
    # If s3 prefix is the full name of a directory and all keys are inside
    # that directory, the contents of said directory will be copied into a
    # directory with the same name as the channel. This behavior is the same
    # whether or not s3 prefix ends with a "/". Moreover, if s3 prefix ends
    # with a "/", this is certainly the behavior to expect, since it can only
    # match a directory.
    # If s3 prefix matches any files exactly, or matches as a prefix of any
    # files or directories, then all files and directories matching s3 prefix
    # will be copied into a directory with the same name as the channel.
    channel_name = input_data["channelName"]
    s3_uri_prefix = input_data["dataSource"]["s3DataSource"]["s3Uri"]
    bucket, prefix = AwsSession.parse_s3_uri(s3_uri_prefix)
    s3_keys = aws_session.list_keys(bucket, prefix)
    top_level = prefix if _is_dir(prefix, s3_keys) else str(Path(prefix).parent)
    found_item = False
    try:
        Path(download_dir, channel_name).mkdir()
    except FileExistsError:
        raise ValueError(f"Duplicate channel names not allowed for input data: {channel_name}")
    for s3_key in s3_keys:
        relative_key = Path(s3_key).relative_to(top_level)
        download_path = Path(download_dir, channel_name, relative_key)
        if not s3_key.endswith("/"):
            download_path.parent.mkdir(parents=True, exist_ok=True)
            aws_session.download_from_s3(
                AwsSession.construct_s3_uri(bucket, s3_key), str(download_path)
            )
            found_item = True
    if not found_item:
        raise RuntimeError(f"No data found for channel '{channel_name}'")


def _is_dir(prefix: str, keys: Iterable[str]) -> bool:
    """Determine whether the prefix refers to a directory.

    Args:
        prefix (str): The prefix to check.
        keys (Iterable[str]): The set of paths to check.

    Returns:
        bool: True if the prefix refers to a directory.
    """
    if prefix.endswith("/"):
        return True
    return all(key.startswith(f"{prefix}/") for key in keys)


def _copy_input_data_list(
    container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs
) -> bool:
    """If the input data list is not empty, this function will download the input files and
    store them in the container.

    Args:
        container (_LocalJobContainer): The container to save input data to.
        aws_session (AwsSession): AwsSession for connecting to AWS Services.

    Returns:
        bool: True if any input data was copied to the container.
    """
    if "inputDataConfig" not in creation_kwargs:
        return False

    input_data_list = creation_kwargs["inputDataConfig"]
    with tempfile.TemporaryDirectory() as temp_dir:
        for input_data in input_data_list:
            _download_input_data(aws_session, temp_dir, input_data)
        container.copy_to(temp_dir, "/opt/ml/input/data/")
    return bool(input_data_list)