# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class to ensure that ``framework_version`` is defined when constructing framework classes."""
from __future__ import absolute_import

import ast
import sys

from packaging.version import InvalidVersion, Version

from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

# Names of the constructor keyword arguments this module inspects or inserts.
FRAMEWORK_ARG = "framework_version"
IMAGE_ARG = "image_uri"
PY_ARG = "py_version"

# Value inserted for ``framework_version`` when a call omits it, keyed by framework class name.
FRAMEWORK_DEFAULTS = {
    "Chainer": "4.1.0",
    "MXNet": "1.2.0",
    "PyTorch": "0.4.0",
    "SKLearn": "0.20.0",
    "TensorFlow": "1.11.0",
}

FRAMEWORK_CLASSES = list(FRAMEWORK_DEFAULTS)

# (class name, lowercase module name) pairs, e.g. ("TensorFlow", "tensorflow").
_FW_MODULES = [(name, name.lower()) for name in FRAMEWORK_CLASSES]

# Estimator class name -> module paths handed to ``matching.matches_any``.
ESTIMATORS = {
    fw: ("sagemaker.{}".format(module), "sagemaker.{}.estimator".format(module))
    for fw, module in _FW_MODULES
}
# TODO: check for sagemaker.tensorflow.serving.Model
# Model class name (``<Framework>Model``) -> module paths handed to ``matching.matches_any``.
MODELS = {
    "{}Model".format(fw): ("sagemaker.{}".format(module), "sagemaker.{}.model".format(module))
    for fw, module in _FW_MODULES
}


class FrameworkVersionEnforcer(Modifier):
    """Ensures that version arguments are defined when instantiating a framework class.

    Applies to both framework estimators and framework models: ``framework_version``
    is added when missing, and ``py_version`` is added when a default exists for the
    framework, framework version and class kind.
    """

    def node_should_be_modified(self, node):
        """Checks if the ast.Call node instantiates a framework estimator or model.

        It doesn't specify the ``framework_version`` and ``py_version`` parameter,
        as appropriate.

        This looks for the following formats:

        - ``TensorFlow``
        - ``sagemaker.tensorflow.TensorFlow``

        where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` is instantiating a framework class that
                should specify ``framework_version`` (and possibly ``py_version``),
                but doesn't.
        """
        if matching.matches_any(node, ESTIMATORS) or matching.matches_any(node, MODELS):
            return _version_args_needed(node)

        return False

    def modify_node(self, node):
        """Modifies the ``ast.Call`` node's keywords to include the version arguments.

        The ``framework_version`` value is determined by the framework:

        - Chainer: "4.1.0"
        - MXNet: "1.2.0"
        - PyTorch: "0.4.0"
        - SKLearn: "0.20.0"
        - TensorFlow: "1.11.0"

        The ``py_version`` value is determined by the framework, framework_version, and if it is a
        model, whether the model accepts a py_version (see ``_py_version_defaults``).

        Args:
            node (ast.Call): a node that represents the constructor of a framework class.

        Returns:
            ast.AST: the original node, which has been potentially modified.
        """
        framework, is_model = _framework_from_node(node)

        # if framework_version is not supplied, get default and append keyword
        if matching.has_arg(node, FRAMEWORK_ARG):
            framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
        else:
            framework_version = FRAMEWORK_DEFAULTS[framework]
            node.keywords.append(
                ast.keyword(arg=FRAMEWORK_ARG, value=_string_literal(framework_version))
            )

        # if py_version is not supplied, get a conditional default, and if not None, append keyword
        if not matching.has_arg(node, PY_ARG):
            py_version = _py_version_defaults(framework, framework_version, is_model)
            if py_version:
                node.keywords.append(ast.keyword(arg=PY_ARG, value=_string_literal(py_version)))
        return node


def _string_literal(value):
    """Builds the AST node for a string literal, portable across Python versions.

    ``ast.Str`` has been a deprecated alias of ``ast.Constant`` since Python 3.8
    and is removed in Python 3.14, so it is only used on interpreters older than 3.8.

    Args:
        value (str): the string the literal should evaluate to.

    Returns:
        ast.AST: an ``ast.Constant`` (Python 3.8+) or ``ast.Str`` node holding ``value``.
    """
    if sys.version_info >= (3, 8):
        return ast.Constant(value=value)
    return ast.Str(s=value)


def _py_version_defaults(framework, framework_version, is_model=False):
    """Picks the ``py_version`` to insert for a framework constructor call.

    Args:
        framework (str): name of the framework
        framework_version (str): version of the framework; only consulted for
            TensorFlow estimators
        is_model (bool): whether it is a constructor for a model or not

    Returns:
        str: the default py version, as appropriate. None if no default py_version
    """
    # These frameworks get the same default for estimators and models alike.
    if framework == "MXNet":
        return "py2"
    if framework in ("Chainer", "PyTorch"):
        return "py3"

    # For everything else, model constructors get no default py_version.
    if is_model:
        return None

    if framework == "SKLearn":
        return "py3"
    if framework == "TensorFlow":
        return _tf_py_version_default(framework_version)
    return None


def _tf_py_version_default(framework_version):
    """Gets the py_version default based on framework_version for TensorFlow.

    Args:
        framework_version (str): the TensorFlow version; may be empty or None.

    Returns:
        str: "py37" for versions 2.2 and later, "py3" for versions from 1.12 up to
            (but excluding) 2.2, and "py2" for anything older, missing, or not
            parseable as a version.
    """
    if framework_version:
        try:
            parsed = Version(framework_version)
        except InvalidVersion:
            parsed = None

        if parsed is not None:
            if parsed >= Version("2.2"):
                return "py37"
            if parsed >= Version("1.12"):
                return "py3"

    # Missing, unparseable, or older than 1.12.
    return "py2"


def _framework_from_node(node):
    """Retrieves the framework class name based on the function call, and if it was a model

    Args:
        node (ast.Call): a node that represents the constructor of a framework class.
            This can represent either <Framework> or sagemaker.<framework>.<Framework>.

    Returns:
        str, bool: the (capitalized) framework class name with any trailing "Model"
            removed, and if it is a model class. The name is an empty string when the
            called object is neither a plain name nor an attribute access.
    """
    func = node.func
    if isinstance(func, ast.Name):
        class_name = func.id
    elif isinstance(func, ast.Attribute):
        class_name = func.attr
    else:
        class_name = ""

    suffix = "Model"
    is_model = class_name.endswith(suffix)
    if is_model:
        # Drop only the trailing suffix; cutting at the first occurrence of "Model"
        # would truncate names that also contain it earlier (e.g. "ModelZooModel").
        class_name = class_name[: -len(suffix)]

    return class_name, is_model


def _version_args_needed(node):
    """Determines whether version arguments still have to be added to the call.

    Applies similar logic as ``validate_version_or_image_args``

    Args:
        node (ast.Call): a node that represents the constructor of a framework class.

    Returns:
        bool: False when ``image_uri`` is supplied; True when ``framework_version`` is
            absent; otherwise True only if a default ``py_version`` exists for this
            call and ``py_version`` is not supplied.
    """
    # An explicit image makes the version arguments unnecessary.
    if matching.has_arg(node, IMAGE_ARG):
        return False

    # Without framework_version the call always needs to be modified.
    if not matching.has_arg(node, FRAMEWORK_ARG):
        return True
    supplied_version = parsing.arg_value(node, FRAMEWORK_ARG)

    # py_version is only required when a default exists -- framework and model dependent.
    fw_name, model_ctor = _framework_from_node(node)
    default_py_version = _py_version_defaults(fw_name, supplied_version, model_ctor)
    return bool(default_py_version) and not matching.has_arg(node, PY_ARG)