# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # coding: utf-8 # pylint: disable=invalid-name, too-many-locals, fixme # pylint: disable=too-many-branches, too-many-statements # pylint: disable=too-many-arguments # pylint: disable=dangerous-default-value """Visualization module""" from __future__ import absolute_import import re import copy import json import warnings from .symbol import Symbol def _str2tuple(string): """Convert shape string to list, internal use only. Parameters ---------- string: str Shape string. Returns ------- list of str Represents shape. """ return re.findall(r"\d+", string) def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74, 1.]): """Convert symbol for detail information. Parameters ---------- symbol: Symbol Symbol to be visualized. shape: dict A dict of shapes, str->shape (tuple), given input shapes. line_length: int Rotal length of printed lines positions: list Relative or absolute positions of log elements in each line. Returns ------ None """ if not isinstance(symbol, Symbol): raise TypeError("symbol must be Symbol") show_shape = False if shape is not None: show_shape = True interals = symbol.get_internals() _, out_shapes, _ = interals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") shape_dict = dict(zip(interals.list_outputs(), out_shapes)) conf = json.loads(symbol.tojson()) nodes = conf["nodes"] heads = set(conf["heads"][0]) if positions[-1] <= 1: positions = [int(line_length * p) for p in positions] # header names for the different log elements to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Previous Layer'] def print_row(fields, positions): """Print format row. Parameters ---------- fields: list Information field. positions: list Field length ratio. Returns ------ None """ line = '' for i, field in enumerate(fields): line += str(field) line = line[:positions[i]] line += ' ' * (positions[i] - len(line)) print(line) print('_' * line_length) print_row(to_display, positions) print('=' * line_length) def print_layer_summary(node, out_shape): """print layer information Parameters ---------- node: dict Node information. out_shape: dict Node shape information. Returns ------ Node total parameters. """ op = node["op"] pre_node = [] pre_filter = 0 if op != "null": inputs = node["inputs"] for item in inputs: input_node = nodes[item[0]] input_name = input_node["name"] if input_node["op"] != "null" or item[0] in heads: # add precede pre_node.append(input_name) if show_shape: if input_node["op"] != "null": key = input_name + "_output" else: key = input_name if key in shape_dict: shape = shape_dict[key][1:] pre_filter = pre_filter + int(shape[0]) cur_param = 0 if op == 'Convolution': if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True': num_group = int(node['attrs'].get('num_group', '1')) cur_param = pre_filter * int(node["attrs"]["num_filter"]) \ // num_group for k in _str2tuple(node["attrs"]["kernel"]): cur_param *= int(k) else: num_group = int(node['attrs'].get('num_group', '1')) cur_param = pre_filter * int(node["attrs"]["num_filter"]) \ // num_group for k in _str2tuple(node["attrs"]["kernel"]): cur_param *= int(k) cur_param += int(node["attrs"]["num_filter"]) elif op == 'FullyConnected': if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True': cur_param = pre_filter * int(node["attrs"]["num_hidden"]) else: cur_param = (pre_filter+1) * int(node["attrs"]["num_hidden"]) elif op == 'BatchNorm': key = node["name"] + "_output" if show_shape: num_filter = shape_dict[key][1] cur_param = int(num_filter) * 2 elif op == 'Embedding': cur_param = int(node["attrs"]['input_dim']) * int(node["attrs"]['output_dim']) if not pre_node: first_connection = '' else: first_connection = pre_node[0] fields = [node['name'] + '(' + op + ')', "x".join([str(x) for x in out_shape]), cur_param, first_connection] print_row(fields, positions) if len(pre_node) > 1: for i in range(1, len(pre_node)): fields = ['', '', '', pre_node[i]] print_row(fields, positions) return cur_param total_params = 0 for i, node in enumerate(nodes): out_shape = [] op = node["op"] if op == "null" and i > 0: continue if op != "null" or i in heads: if show_shape: if op != "null": key = node["name"] + "_output" else: key = node["name"] if key in shape_dict: out_shape = shape_dict[key][1:] total_params += print_layer_summary(nodes[i], out_shape) if i == len(nodes) - 1: print('=' * line_length) else: print('_' * line_length) print('Total params: %s' % total_params) print('_' * line_length) def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={}, hide_weights=True): """Creates a visualization (Graphviz digraph object) of the given computation graph. Graphviz must be installed for this function to work. Parameters ---------- title: str, optional Title of the generated visualization. symbol: Symbol A symbol from the computation graph. The generated digraph will visualize the part of the computation graph required to compute `symbol`. shape: dict, optional Specifies the shape of the input tensors. If specified, the visualization will include the shape of the tensors between the nodes. `shape` is a dictionary mapping input symbol names (str) to the corresponding tensor shape (tuple). node_attrs: dict, optional Specifies the attributes for nodes in the generated visualization. `node_attrs` is a dictionary of Graphviz attribute names and values. For example, ``node_attrs={"shape":"oval","fixedsize":"false"}`` will use oval shape for nodes and allow variable sized nodes in the visualization. hide_weights: bool, optional If True (default), then inputs with names of form *_weight (corresponding to weight tensors) or *_bias (corresponding to bias vectors) will be hidden for a cleaner visualization. Returns ------- dot: Digraph A Graphviz digraph object visualizing the computation graph to compute `symbol`. Example ------- >>> net = mx.sym.Variable('data') >>> net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128) >>> net = mx.sym.Activation(data=net, name='relu1', act_type="relu") >>> net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10) >>> net = mx.sym.SoftmaxOutput(data=net, name='out') >>> digraph = mx.viz.plot_network(net, shape={'data':(100,200)}, ... node_attrs={"fixedsize":"false"}) >>> digraph.view() """ # todo add shape support try: from graphviz import Digraph except: raise ImportError("Draw network requires graphviz library") if not isinstance(symbol, Symbol): raise TypeError("symbol must be a Symbol") draw_shape = False if shape is not None: draw_shape = True interals = symbol.get_internals() _, out_shapes, _ = interals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") shape_dict = dict(zip(interals.list_outputs(), out_shapes)) conf = json.loads(symbol.tojson()) nodes = conf["nodes"] # check if multiple nodes have the same name if len(nodes) != len(set([node["name"] for node in nodes])): seen_nodes = set() # find all repeated names repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes or seen_nodes.add(node['name'])) warning_message = "There are multiple variables with the same name in your graph, " \ "this may result in cyclic graph. Repeated names: " + ','.join(repeated) warnings.warn(warning_message, RuntimeWarning) # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} # merge the dict provided by user and the default one node_attr.update(node_attrs) dot = Digraph(name=title, format=save_format) # color map cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3", "#fdb462", "#b3de69", "#fccde5") def looks_like_weight(name): """Internal helper to figure out if node should be hidden with `hide_weights`. """ weight_like = ('_weight', '_bias', '_beta', '_gamma', '_moving_var', '_moving_mean', '_running_var', '_running_mean') return name.endswith(weight_like) # make nodes hidden_nodes = set() for node in nodes: op = node["op"] name = node["name"] # input data attr = copy.deepcopy(node_attr) label = name if op == "null": if looks_like_weight(node["name"]): if hide_weights: hidden_nodes.add(node["name"]) # else we don't render a node, but # don't add it to the hidden_nodes set # so it gets rendered as an empty oval continue attr["shape"] = "oval" # inputs get their own shape label = node["name"] attr["fillcolor"] = cm[0] elif op == "Convolution": label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attrs"]["kernel"])), "x".join(_str2tuple(node["attrs"]["stride"])) if "stride" in node["attrs"] else "1", node["attrs"]["num_filter"]) attr["fillcolor"] = cm[1] elif op == "FullyConnected": label = r"FullyConnected\n%s" % node["attrs"]["num_hidden"] attr["fillcolor"] = cm[1] elif op == "BatchNorm": attr["fillcolor"] = cm[3] elif op in ('Activation', 'LeakyReLU'): label = r"%s\n%s" % (op, node["attrs"]["act_type"]) attr["fillcolor"] = cm[2] elif op == "Pooling": label = r"Pooling\n%s, %s/%s" % (node["attrs"]["pool_type"], "x".join(_str2tuple(node["attrs"]["kernel"])), "x".join(_str2tuple(node["attrs"]["stride"])) if "stride" in node["attrs"] else "1") attr["fillcolor"] = cm[4] elif op in ("Concat", "Flatten", "Reshape"): attr["fillcolor"] = cm[5] elif op == "Softmax": attr["fillcolor"] = cm[6] else: attr["fillcolor"] = cm[7] if op == "Custom": label = node["attrs"]["op_type"] dot.node(name=name, label=label, **attr) # add edges for node in nodes: # pylint: disable=too-many-nested-blocks op = node["op"] name = node["name"] if op == "null": continue else: inputs = node["inputs"] for item in inputs: input_node = nodes[item[0]] input_name = input_node["name"] if input_name not in hidden_nodes: attr = {"dir": "back", 'arrowtail':'open'} # add shapes if draw_shape: if input_node["op"] != "null": key = input_name + "_output" if "attrs" in input_node: params = input_node["attrs"] if "num_outputs" in params: key += str(int(params["num_outputs"]) - 1) shape = shape_dict[key][1:] label = "x".join([str(x) for x in shape]) attr["label"] = label else: key = input_name shape = shape_dict[key][1:] label = "x".join([str(x) for x in shape]) attr["label"] = label dot.edge(tail_name=name, head_name=input_name, **attr) return dot