# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, too-many-locals, fixme
# pylint: disable=too-many-branches, too-many-statements
# pylint: disable=too-many-arguments
# pylint: disable=dangerous-default-value
"""Visualization module"""
from __future__ import absolute_import

import re
import copy
import json
import warnings
from .symbol import Symbol

def _str2tuple(string):
    """Convert shape string to list, internal use only.

    Parameters
    ----------
    string: str
        Shape string.

    Returns
    -------
    list of str
        Represents shape.
    """
    return re.findall(r"\d+", string)

def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74, 1.]):
    """Convert symbol for detail information.

    Parameters
    ----------
    symbol: Symbol
        Symbol to be visualized.
    shape: dict
        A dict of shapes, str->shape (tuple), given input shapes.
    line_length: int
        Rotal length of printed lines
    positions: list
        Relative or absolute positions of log elements in each line.
    Returns
    ------
    None
    """
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be Symbol")
    show_shape = False
    if shape is not None:
        show_shape = True
        interals = symbol.get_internals()
        _, out_shapes, _ = interals.infer_shape(**shape)
        if out_shapes is None:
            raise ValueError("Input shape is incomplete")
        shape_dict = dict(zip(interals.list_outputs(), out_shapes))
    conf = json.loads(symbol.tojson())
    nodes = conf["nodes"]
    heads = set(conf["heads"][0])
    if positions[-1] <= 1:
        positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Previous Layer']
    def print_row(fields, positions):
        """Print format row.

        Parameters
        ----------
        fields: list
            Information field.
        positions: list
            Field length ratio.
        Returns
        ------
        None
        """
        line = ''
        for i, field in enumerate(fields):
            line += str(field)
            line = line[:positions[i]]
            line += ' ' * (positions[i] - len(line))
        print(line)
    print('_' * line_length)
    print_row(to_display, positions)
    print('=' * line_length)
    def print_layer_summary(node, out_shape):
        """print layer information

        Parameters
        ----------
        node: dict
            Node information.
        out_shape: dict
            Node shape information.
        Returns
        ------
            Node total parameters.
        """
        op = node["op"]
        pre_node = []
        pre_filter = 0
        if op != "null":
            inputs = node["inputs"]
            for item in inputs:
                input_node = nodes[item[0]]
                input_name = input_node["name"]
                if input_node["op"] != "null" or item[0] in heads:
                    # add precede
                    pre_node.append(input_name)
                    if show_shape:
                        if input_node["op"] != "null":
                            key = input_name + "_output"
                        else:
                            key = input_name
                        if key in shape_dict:
                            shape = shape_dict[key][1:]
                            pre_filter = pre_filter + int(shape[0])
        cur_param = 0
        if op == 'Convolution':
            if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True':
                num_group = int(node['attrs'].get('num_group', '1'))
                cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
                   // num_group
                for k in _str2tuple(node["attrs"]["kernel"]):
                    cur_param *= int(k)
            else:
                num_group = int(node['attrs'].get('num_group', '1'))
                cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
                   // num_group
                for k in _str2tuple(node["attrs"]["kernel"]):
                    cur_param *= int(k)
                cur_param += int(node["attrs"]["num_filter"])
        elif op == 'FullyConnected':
            if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True':
                cur_param = pre_filter * int(node["attrs"]["num_hidden"])
            else:
                cur_param = (pre_filter+1) * int(node["attrs"]["num_hidden"])
        elif op == 'BatchNorm':
            key = node["name"] + "_output"
            if show_shape:
                num_filter = shape_dict[key][1]
                cur_param = int(num_filter) * 2
        elif op == 'Embedding':
            cur_param = int(node["attrs"]['input_dim']) * int(node["attrs"]['output_dim'])
        if not pre_node:
            first_connection = ''
        else:
            first_connection = pre_node[0]
        fields = [node['name'] + '(' + op + ')',
                  "x".join([str(x) for x in out_shape]),
                  cur_param,
                  first_connection]
        print_row(fields, positions)
        if len(pre_node) > 1:
            for i in range(1, len(pre_node)):
                fields = ['', '', '', pre_node[i]]
                print_row(fields, positions)
        return cur_param
    total_params = 0
    for i, node in enumerate(nodes):
        out_shape = []
        op = node["op"]
        if op == "null" and i > 0:
            continue
        if op != "null" or i in heads:
            if show_shape:
                if op != "null":
                    key = node["name"] + "_output"
                else:
                    key = node["name"]
                if key in shape_dict:
                    out_shape = shape_dict[key][1:]
        total_params += print_layer_summary(nodes[i], out_shape)
        if i == len(nodes) - 1:
            print('=' * line_length)
        else:
            print('_' * line_length)
    print('Total params: %s' % total_params)
    print('_' * line_length)

def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={},
                 hide_weights=True):
    """Creates a visualization (Graphviz digraph object) of the given computation graph.
    Graphviz must be installed for this function to work.

    Parameters
    ----------
    title: str, optional
        Title of the generated visualization.
    symbol: Symbol
        A symbol from the computation graph. The generated digraph will visualize the part
        of the computation graph required to compute `symbol`.
    shape: dict, optional
        Specifies the shape of the input tensors. If specified, the visualization will include
        the shape of the tensors between the nodes. `shape` is a dictionary mapping
        input symbol names (str) to the corresponding tensor shape (tuple).
    node_attrs: dict, optional
        Specifies the attributes for nodes in the generated visualization. `node_attrs` is
        a dictionary of Graphviz attribute names and values. For example,
            ``node_attrs={"shape":"oval","fixedsize":"false"}``
            will use oval shape for nodes and allow variable sized nodes in the visualization.
    hide_weights: bool, optional
        If True (default), then inputs with names of form *_weight (corresponding to weight
        tensors) or *_bias (corresponding to bias vectors) will be hidden for a cleaner
        visualization.

    Returns
    -------
    dot: Digraph
        A Graphviz digraph object visualizing the computation graph to compute `symbol`.

    Example
    -------
    >>> net = mx.sym.Variable('data')
    >>> net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)
    >>> net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
    >>> net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)
    >>> net = mx.sym.SoftmaxOutput(data=net, name='out')
    >>> digraph = mx.viz.plot_network(net, shape={'data':(100,200)},
    ... node_attrs={"fixedsize":"false"})
    >>> digraph.view()
    """
    # todo add shape support
    try:
        from graphviz import Digraph
    except:
        raise ImportError("Draw network requires graphviz library")
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be a Symbol")
    draw_shape = False
    if shape is not None:
        draw_shape = True
        interals = symbol.get_internals()
        _, out_shapes, _ = interals.infer_shape(**shape)
        if out_shapes is None:
            raise ValueError("Input shape is incomplete")
        shape_dict = dict(zip(interals.list_outputs(), out_shapes))
    conf = json.loads(symbol.tojson())
    nodes = conf["nodes"]
    # check if multiple nodes have the same name
    if len(nodes) != len(set([node["name"] for node in nodes])):
        seen_nodes = set()
        # find all repeated names
        repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes
                       or seen_nodes.add(node['name']))
        warning_message = "There are multiple variables with the same name in your graph, " \
                          "this may result in cyclic graph. Repeated names: " + ','.join(repeated)
        warnings.warn(warning_message, RuntimeWarning)
    # default attributes of node
    node_attr = {"shape": "box", "fixedsize": "true",
                 "width": "1.3", "height": "0.8034", "style": "filled"}
    # merge the dict provided by user and the default one
    node_attr.update(node_attrs)
    dot = Digraph(name=title, format=save_format)
    # color map
    cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
          "#fdb462", "#b3de69", "#fccde5")

    def looks_like_weight(name):
        """Internal helper to figure out if node should be hidden with `hide_weights`.
        """
        weight_like = ('_weight', '_bias', '_beta', '_gamma',
                       '_moving_var', '_moving_mean', '_running_var', '_running_mean')
        return name.endswith(weight_like)

    # make nodes
    hidden_nodes = set()
    for node in nodes:
        op = node["op"]
        name = node["name"]
        # input data
        attr = copy.deepcopy(node_attr)
        label = name

        if op == "null":
            if looks_like_weight(node["name"]):
                if hide_weights:
                    hidden_nodes.add(node["name"])
                # else we don't render a node, but
                # don't add it to the hidden_nodes set
                # so it gets rendered as an empty oval
                continue
            attr["shape"] = "oval" # inputs get their own shape
            label = node["name"]
            attr["fillcolor"] = cm[0]
        elif op == "Convolution":
            label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attrs"]["kernel"])),
                                                 "x".join(_str2tuple(node["attrs"]["stride"]))
                                                 if "stride" in node["attrs"] else "1",
                                                 node["attrs"]["num_filter"])
            attr["fillcolor"] = cm[1]
        elif op == "FullyConnected":
            label = r"FullyConnected\n%s" % node["attrs"]["num_hidden"]
            attr["fillcolor"] = cm[1]
        elif op == "BatchNorm":
            attr["fillcolor"] = cm[3]
        elif op in ('Activation', 'LeakyReLU'):
            label = r"%s\n%s" % (op, node["attrs"]["act_type"])
            attr["fillcolor"] = cm[2]
        elif op == "Pooling":
            label = r"Pooling\n%s, %s/%s" % (node["attrs"]["pool_type"],
                                             "x".join(_str2tuple(node["attrs"]["kernel"])),
                                             "x".join(_str2tuple(node["attrs"]["stride"]))
                                             if "stride" in node["attrs"] else "1")
            attr["fillcolor"] = cm[4]
        elif op in ("Concat", "Flatten", "Reshape"):
            attr["fillcolor"] = cm[5]
        elif op == "Softmax":
            attr["fillcolor"] = cm[6]
        else:
            attr["fillcolor"] = cm[7]
            if op == "Custom":
                label = node["attrs"]["op_type"]

        dot.node(name=name, label=label, **attr)

    # add edges
    for node in nodes:          # pylint: disable=too-many-nested-blocks
        op = node["op"]
        name = node["name"]
        if op == "null":
            continue
        else:
            inputs = node["inputs"]
            for item in inputs:
                input_node = nodes[item[0]]
                input_name = input_node["name"]
                if input_name not in hidden_nodes:
                    attr = {"dir": "back", 'arrowtail':'open'}
                    # add shapes
                    if draw_shape:
                        if input_node["op"] != "null":
                            key = input_name + "_output"
                            if "attrs" in input_node:
                                params = input_node["attrs"]
                                if "num_outputs" in params:
                                    key += str(int(params["num_outputs"]) - 1)
                            shape = shape_dict[key][1:]
                            label = "x".join([str(x) for x in shape])
                            attr["label"] = label
                        else:
                            key = input_name
                            shape = shape_dict[key][1:]
                            label = "x".join([str(x) for x in shape])
                            attr["label"] = label
                    dot.edge(tail_name=name, head_name=input_name, **attr)

    return dot