# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import os
from enum import Enum
from typing import Dict


class Framework(str, Enum):
    """Supported Frameworks for pre-built containers"""

    BASE = "BASE"
    PL_TENSORFLOW = "PL_TENSORFLOW"
    PL_PYTORCH = "PL_PYTORCH"


def retrieve_image(framework: Framework, region: str) -> str:
    """Retrieves the ECR URI for the Docker image matching the specified arguments.

    Args:
        framework (Framework): The name of the framework.
        region (str): The AWS region for the Docker image.

    Returns:
        str: The ECR URI for the corresponding Amazon Braket Docker image.

    Raises:
        ValueError: If any of the supplied values are invalid or the combination of inputs
            specified is not supported.
    """
    # Validate framework
    framework = Framework(framework)
    config = _config_for_framework(framework)
    registry = _registry_for_region(config, region)
    tag = f"{config['repository']}:latest"
    return f"{registry}.dkr.ecr.{region}.amazonaws.com/{tag}"


def _config_for_framework(framework: Framework) -> Dict[str, str]:
    """Loads the JSON config for the given framework.

    Args:
        framework (Framework): The framework whose config needs to be loaded.

    Returns:
        Dict[str, str]: Dict that contains the configuration for the specified framework.
    """
    fname = os.path.join(os.path.dirname(__file__), "image_uri_config", f"{framework.lower()}.json")
    with open(fname) as f:
        return json.load(f)


def _registry_for_region(config: Dict[str, str], region: str) -> str:
    """Retrieves the registry for the specified region from the configuration.

    Args:
        config (Dict[str, str]): Dict containing the framework configuration.
        region (str): str that specifies the region for which the registry is retrieved.

    Returns:
        str: str that specifies the registry for the supplied region.

    Raises:
        ValueError: If the supplied region is invalid or not supported.
    """
    if region not in (supported_regions := config["supported_regions"]):
        raise ValueError(
            f"Unsupported region: {region}. You may need to upgrade your SDK version for newer "
            f"regions. Supported region(s): {supported_regions}"
        )
    return config["registry"]