import tensorflow as tf

from typing import List, Dict
from pathlib import Path
from enum import Enum
from collections import OrderedDict
from .abstract_model_helper import ModelHelper


class TFModelFormat(Enum):
    """On-disk TensorFlow model layouts distinguished by this module."""

    # A single serialized graph file.
    FrozenGraphModel = "frozen_graph_model"
    # A SavedModel export directory.
    SavedModel = "saved_model"


class TFModelHelper(ModelHelper):
    """Inspect a TensorFlow model and expose its input and output tensors.

    Two layouts are recognised (see ``model_type``): a frozen graph file and a
    SavedModel directory.  The ``extract_*`` methods populate the
    ``input_tensors``, ``output_tensors``, ``input_tensor_names`` and
    ``output_tensor_names`` properties; ``get_metadata`` summarises them as
    plain dictionaries.
    """

    # Op types that are never reported as outputs when scanning a frozen graph.
    UNLIKELY_OUTPUT_TYPES = {"Const", "Assign", "NoOp", "Placeholder"}

    def __init__(self, model_path: str, data_shape: Dict[str, List[int]]) -> None:
        """Create a helper for the model stored at ``model_path``.

        Args:
            model_path: Location of the model; forwarded to ``ModelHelper``.
            data_shape: Mapping of input name to shape.  It is only consulted
                by the v2 SavedModel paths, and only when the loaded model
                exposes no signatures.
        """
        super().__init__(model_path)
        tf.enable_eager_execution()
        self.__input_tensor_names = []
        self.__output_tensor_names = []
        self.__input_tensors = []
        self.__output_tensors = []
        self.__model_tf_version = "not found"
        self.__data_shape = data_shape

    @property
    def model_type(self) -> TFModelFormat:
        """Classify the model on disk.

        Returns:
            ``TFModelFormat.FrozenGraphModel`` for a ``.pb``/``.pbtxt`` file, or
            ``TFModelFormat.SavedModel`` for a directory that contains a
            ``variables`` entry.

        Raises:
            Exception: If the path matches neither layout.
        """
        path = self.model_path
        if path.is_file() and path.suffix in ('.pb', '.pbtxt'):
            return TFModelFormat.FrozenGraphModel
        if path.is_dir() and path.joinpath('variables').exists():
            return TFModelFormat.SavedModel
        raise Exception(f"Encountered invalid model file format. {path.as_posix()}")

    @property
    def input_tensors(self) -> List[tf.Tensor]:
        """Input tensors found by the most recent extraction call."""
        return self.__input_tensors

    @property
    def output_tensors(self) -> List[tf.Tensor]:
        """Output tensors found by the most recent extraction call."""
        return self.__output_tensors

    @property
    def input_tensor_names(self) -> List[str]:
        """Names of ``input_tensors``, in the same order."""
        return self.__input_tensor_names

    @property
    def output_tensor_names(self) -> List[str]:
        """Names of ``output_tensors``, in the same order."""
        return self.__output_tensor_names

    def __get_graph_from_frozen_graph_model(self) -> tf.Graph:
        """Read the frozen graph file and import it into a fresh ``tf.Graph``."""
        # NOTE(review): the file is always parsed as a binary GraphDef even
        # though model_type also accepts ".pbtxt" -- confirm whether text-format
        # graphs are expected to reach this point.
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(self.model_path.as_posix(), 'rb') as f:
            graph_def.ParseFromString(f.read())

        with tf.Graph().as_default() as graph:
            # An empty name makes the import prefix an empty string, so node
            # names stay exactly as stored in the file.  This is not strictly
            # required because the graph is not modified here; a prefix is only
            # useful for telling imported nodes apart from modified ones.
            tf.import_graph_def(graph_def, name="")
        return graph

    def __extract_input_and_output_tensors_from_frozen_graph(self) -> None:
        """Populate the tensor lists by scanning the ops of a frozen graph.

        Inputs are the single outputs of ``Placeholder`` ops that have no
        inputs.  Outputs are the single outputs of ops whose type is not in
        ``UNLIKELY_OUTPUT_TYPES`` and which are neither consumed by another op
        nor produced by an op used as a control input.
        """
        # https://github.com/neo-ai/neo-ai-dlr/blob/master/python/dlr/tf_model.py#L37
        tf.reset_default_graph()
        graph = self.__get_graph_from_frozen_graph_model()
        input_tensors = OrderedDict()
        output_candidates = OrderedDict()
        # Names of tensors that feed some op, directly or through a control edge.
        consumed_names = set()

        for op in graph.get_operations():
            if len(op.outputs) == 1:
                tensor = op.outputs[0]
                if op.type == 'Placeholder' and len(op.inputs) == 0:
                    input_tensors[tensor.name] = tensor
                if op.type not in self.UNLIKELY_OUTPUT_TYPES:
                    output_candidates[tensor.name] = tensor

            consumed_names.update(in_t.name for in_t in op.inputs)
            for cont_op in op.control_inputs:
                consumed_names.update(out_t.name for out_t in cont_op.outputs)

        # Filtering after the scan keeps the graph order of the survivors.
        output_tensors = OrderedDict(
            (name, tensor)
            for name, tensor in output_candidates.items()
            if name not in consumed_names
        )

        tf.reset_default_graph()

        self.__input_tensor_names = list(input_tensors.keys())
        self.__output_tensor_names = list(output_tensors.keys())
        self.__input_tensors = list(input_tensors.values())
        self.__output_tensors = list(output_tensors.values())

    def __get_tag_set(self) -> List[str]:
        """Return the first tag set reported for the SavedModel.

        Raises:
            ImportError: If ``tensorflow.contrib``'s saved_model reader cannot
                be imported.
        """
        try:
            from tensorflow.contrib.saved_model.python.saved_model import reader
        except ImportError as err:
            raise ImportError(
                "InputConfiguration: Unable to import saved_model.reader which is "
                "required to get tag set from saved model.") from err

        tag_sets = reader.get_saved_model_tag_sets(self.model_path.as_posix())
        return tag_sets[0]

    def __first_input_spec(self) -> tf.TensorSpec:
        """Build a ``tf.TensorSpec`` from the first entry of ``data_shape``.

        Raises:
            ValueError: If no data shape was supplied.
        """
        if not self.__data_shape:
            raise ValueError(
                "A data shape is required to trace a SavedModel that has no signatures.")
        first_shape = next(iter(self.__data_shape.values()))
        return tf.TensorSpec(tuple(first_shape))

    def __select_concrete_function(self, loaded):
        """Pick the function of a TF2-loaded SavedModel that should be frozen.

        Preference order: trace ``__call__`` with the first data shape when the
        model has no signatures, otherwise ``serving_default``, otherwise the
        first signature listed.
        """
        signatures = loaded.signatures
        if len(signatures) == 0:
            # The spec is only built here so that models with signatures do not
            # need a data shape at all.
            return loaded.__call__.get_concrete_function(self.__first_input_spec())
        if 'serving_default' in signatures:
            return signatures['serving_default']
        return signatures[next(iter(signatures.keys()))]

    def __extract_input_and_output_tensors_from_saved_model(self) -> None:
        """Populate the tensor lists from the signature defs of a SavedModel.

        The model is loaded into a TF1 session and every input and output named
        by any signature def is resolved to a tensor of the loaded graph.
        """
        # https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/frontend/tensorflow_parser.py#L73
        tf.reset_default_graph()

        tags = self.__get_tag_set()
        input_tensors = OrderedDict()
        output_tensors = OrderedDict()

        with tf.Session() as sess:
            meta_graph_def = tf.saved_model.loader.load(sess, tags, self.model_path.as_posix())
            graph = tf.get_default_graph()
            for sig_def in meta_graph_def.signature_def.values():
                for input_info in sig_def.inputs.values():
                    input_tensors[input_info.name] = graph.get_tensor_by_name(input_info.name)
                for output_info in sig_def.outputs.values():
                    output_tensors[output_info.name] = graph.get_tensor_by_name(output_info.name)

        tf.reset_default_graph()

        self.__input_tensor_names = list(input_tensors.keys())
        self.__output_tensor_names = list(output_tensors.keys())
        self.__input_tensors = list(input_tensors.values())
        self.__output_tensors = list(output_tensors.values())

    def __extract_input_and_output_tensors_from_saved_model_v2(self) -> None:
        """Populate the tensor lists from a SavedModel loaded with the TF2 loader.

        The selected function is frozen with
        ``convert_variables_to_constants_v2`` and its inputs and outputs become
        the helper's tensors.
        """
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        tags = self.__get_tag_set()
        loaded = tf.compat.v2.saved_model.load(self.model_path.as_posix(), tags=tags)
        concrete_function = self.__select_concrete_function(loaded)
        frozen_func = convert_variables_to_constants_v2(concrete_function, lower_control_flow=True)

        # Rebuild the name lists instead of appending so that repeated calls do
        # not leave names and tensors out of sync.
        self.__input_tensors = list(frozen_func.inputs)
        self.__output_tensors = list(frozen_func.outputs)
        self.__input_tensor_names = [tensor.name for tensor in self.__input_tensors]
        self.__output_tensor_names = [tensor.name for tensor in self.__output_tensors]

    def extract_input_and_output_tensors(self, user_shape_dict=None) -> None:
        """Populate the tensor lists using the TF1 loading paths.

        Args:
            user_shape_dict: Currently unused.
        """
        if self.model_type == TFModelFormat.SavedModel:
            self.__extract_input_and_output_tensors_from_saved_model()
        else:
            self.__extract_input_and_output_tensors_from_frozen_graph()

    def extract_input_and_output_tensors_v2(self, user_shape_dict=None) -> None:
        """Populate the tensor lists, using the TF2 loader for SavedModels.

        Frozen graphs are handled exactly as in
        ``extract_input_and_output_tensors``.

        Args:
            user_shape_dict: Currently unused.
        """
        if self.model_type == TFModelFormat.SavedModel:
            self.__extract_input_and_output_tensors_from_saved_model_v2()
        else:
            self.__extract_input_and_output_tensors_from_frozen_graph()

    def get_metadata(self) -> Dict[str, List]:
        """Describe the extracted tensors as plain dictionaries.

        Returns:
            A dict with ``"Inputs"`` and ``"Outputs"`` lists.  Each entry holds
            the tensor's ``name``, ``dtype`` and ``shape``; ``shape`` is ``None``
            when the tensor's shape is falsy (unknown).
        """
        # We need to strip the trailing ":0" from the input names since DLR cannot handle it.
        # RelayTVM handles it gracefully and we can pass the names as is to relay.
        # https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/frontend/tensorflow.py#L2860
        return {
            "Inputs": [
                {
                    'name': tensor.name.replace(":0", ""),
                    'dtype': tensor.dtype.name,
                    # Same guard as for outputs: avoid calling as_list() on an unknown shape.
                    'shape': tensor.shape.as_list() if tensor.shape else None,
                }
                for tensor in self.input_tensors
            ],
            "Outputs": [
                {
                    'name': tensor.name,
                    'dtype': tensor.dtype.name,
                    'shape': tensor.shape.as_list() if tensor.shape else None,
                }
                for tensor in self.output_tensors
            ]
        }

    def get_tf_graph_from_graph_model_v2(self):
        """Load the SavedModel with the TF2 loader and return its frozen GraphDef.

        Also records the TensorFlow version stored on the loaded model so that
        ``get_tensorflow_version`` can report it.

        Returns:
            The graph of the frozen function as a GraphDef with shapes added.
        """
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        tags = self.__get_tag_set()
        loaded = tf.compat.v2.saved_model.load(self.model_path.as_posix(), tags=tags)
        self.__model_tf_version = loaded.tensorflow_version
        concrete_function = self.__select_concrete_function(loaded)
        frozen_func = convert_variables_to_constants_v2(concrete_function, lower_control_flow=True)
        return frozen_func.graph.as_graph_def(add_shapes=True)

    def get_tensorflow_version(self) -> str:
        """Return the TensorFlow version associated with the model.

        Frozen graphs always report ``"1.x"``.  Otherwise the value recorded by
        ``get_tf_graph_from_graph_model_v2`` is returned, or ``"not found"`` if
        that method has not run.
        """
        if self.model_type == TFModelFormat.FrozenGraphModel:
            # Frozen graph creation is deprecated in TF 2.x.
            return "1.x"
        return self.__model_tf_version