# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name,too-many-locals,no-self-use
""" Support import export formats."""
from __future__ import absolute_import as _abs
from .... import symbol
from .... import ndarray as nd
from ....base import string_types
from ._import_helper import _convert_map as convert_map


class GraphProto(object): # pylint: disable=too-few-public-methods
    """A helper class for handling mxnet symbol copying from pb2.GraphProto.

    Converts an ONNX ``GraphProto`` into an mxnet symbol plus parameter
    dicts (``from_onnx``), or directly into a Gluon ``SymbolBlock``
    (``graph_to_gluon``).

    NOTE(review): conversion state (``_nodes``, ``_params``, ``arg_dict``,
    ``aux_dict``) is never reset between calls, so presumably one instance
    is meant to convert one graph.

    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
    """
    def __init__(self):
        # ONNX value name (graph input or node output) -> mxnet symbol.
        self._nodes = {}
        # Initializer name -> parsed nd.array value.
        self._params = {}
        # Not read by any method of this class; kept for compatibility.
        self._num_input = 0
        self._num_param = 0
        # Initializers used as auxiliary states / arguments of the symbol.
        self.aux_dict = {}
        self.arg_dict = {}
        # Input/output names and shapes, filled in by from_onnx.
        self.model_metadata = {}

    def _convert_operator(self, node_name, op_name, attrs, inputs):
        """Convert an ONNX operator into an mxnet symbol.

        The handler registered in ``convert_map`` for ``op_name`` translates
        the operator name, attributes and inputs. If the handler returns an
        operator name (a string), the matching ``mxnet.symbol`` function is
        looked up and called. Otherwise, the handler's first return value is
        returned as-is (presumably an already-built symbol).

        Parameters
        ----------
        node_name : str or None
            Name to give the resulting symbol; ``None`` lets mxnet choose.
        op_name : str
            ONNX operator type (``NodeProto.op_type``).
        attrs : dict
            Operator attributes, as produced by ``_parse_attr``.
        inputs : list
            Input symbols of the operator.

        Returns
        -------
        mxnet_sym : symbol.Symbol
            The converted mxnet symbol.

        Raises
        ------
        NotImplementedError
            If ``convert_map`` has no handler for ``op_name``.
        RuntimeError
            If the handler names an operator missing from ``mxnet.symbol``.
        """
        if op_name not in convert_map:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
        op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)

        if not isinstance(op_name, string_types):
            # The handler built the symbol itself.
            return op_name

        new_op = getattr(symbol, op_name, None)
        if not new_op:
            raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
        if node_name is None:
            return new_op(*inputs, **new_attrs)
        return new_op(name=node_name, *inputs, **new_attrs)

    def from_onnx(self, graph):
        """Construct an mxnet symbol from an onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph.

        Returns
        -------
        sym : symbol.Symbol
            The returned mxnet symbol; a ``symbol.Group`` when the graph has
            more than one output.
        arg_params : dict
            name: nd.array pairs for initializers that are arguments of the
            converted symbols, used as pretrained weights.
        aux_params : dict
            name: nd.array pairs for initializers that are auxiliary states
            of the converted symbols.

        Raises
        ------
        ValueError
            If an initializer tensor has an empty (or all-whitespace) name.
        """
        # Record input/output names and shapes (reused by graph_to_gluon).
        self.model_metadata = self.get_graph_metadata(graph)

        # Initializers hold the network's parameters.
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)

        # Every graph input becomes a Variable; inputs backed by an
        # initializer (i.e. parameters) also carry that parameter's shape.
        for graph_input in graph.input:
            name = graph_input.name
            if name in self._params:
                self._nodes[name] = symbol.Variable(name=name, shape=self._params[name].shape)
            else:
                self._nodes[name] = symbol.Variable(name=name)

        # Convert nodes in graph order. Each node's inputs are looked up in
        # self._nodes, so this relies on graph.node being topologically
        # sorted (producers before consumers), as ONNX requires.
        for node in graph.node:
            node_name = node.name.strip() or None
            onnx_attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[input_name] for input_name in node.input]
            mxnet_sym = self._convert_operator(node_name, node.op_type, onnx_attr, inputs)

            # Bind each ONNX output name to the matching symbol output.
            num_outputs = len(mxnet_sym.list_outputs())
            for out_name, out_idx in zip(node.output, range(num_outputs)):
                self._nodes[out_name] = mxnet_sym[out_idx]

            # Split the initializers this symbol uses into args and aux params.
            for arg_name in mxnet_sym.list_arguments():
                if arg_name in self._params:
                    self.arg_dict[arg_name] = nd.array(self._params[arg_name])
            for aux_name in mxnet_sym.list_auxiliary_states():
                if aux_name in self._params:
                    self.aux_dict[aux_name] = nd.array(self._params[aux_name])

        outputs = [self._nodes[graph_output.name] for graph_output in graph.output]
        if len(outputs) > 1:
            sym = symbol.Group(outputs)
        else:
            sym = outputs[0]
        return sym, self.arg_dict, self.aux_dict

    def get_graph_metadata(self, graph):
        """Get the model metadata from a given onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph.

        Returns
        -------
        metadata : dict
            ``'input_tensor_data'``: list of ``(name, shape)`` tuples for the
            graph inputs that are not initializers.
            ``'output_tensor_data'``: list of ``(name, shape)`` tuples for the
            graph outputs.
            Each shape is a tuple of the dimensions' ``dim_value`` fields.
        """
        param_names = set(tensor.name for tensor in graph.initializer)

        def _shape_of(value_info):
            # NOTE(review): dimensions given only symbolically (dim_param)
            # presumably report dim_value 0 here -- confirm before relying
            # on these shapes.
            return tuple(dim.dim_value for dim in value_info.type.tensor_type.shape.dim)

        input_data = [(graph_input.name, _shape_of(graph_input))
                      for graph_input in graph.input
                      if graph_input.name not in param_names]
        output_data = [(graph_out.name, _shape_of(graph_out))
                       for graph_out in graph.output]
        return {'input_tensor_data': input_data,
                'output_tensor_data': output_data}

    def graph_to_gluon(self, graph, ctx):
        """Construct a SymbolBlock from an onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph.
        ctx : Context or list of Context
            Loads the model into one or many context(s).

        Returns
        -------
        sym_block : gluon.nn.SymbolBlock
            The returned gluon SymbolBlock, with the graph's initializers
            loaded as its parameters.
        """
        sym, arg_params, aux_params = self.from_onnx(graph)
        # from_onnx has just stored this graph's metadata; reuse it instead of
        # walking the graph a second time.
        metadata = self.model_metadata
        data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
        data_inputs = [symbol.var(data_name) for data_name in data_names]

        from ....gluon import SymbolBlock
        net = SymbolBlock(outputs=sym, inputs=data_inputs)
        net_params = net.collect_params()
        # Load args first, then aux states, into the block's parameters.
        for params in (arg_params, aux_params):
            for name in params:
                if name in net_params:
                    net_params[name].shape = params[name].shape
                    net_params[name]._load_init(params[name], ctx=ctx)
        return net

    def _parse_array(self, tensor_proto):
        """Grab data in TensorProto and convert it to an mxnet nd.array.

        Raises
        ------
        ImportError
            If the ``onnx`` package is not installed.
        """
        try:
            from onnx.numpy_helper import to_array
        except ImportError:
            raise ImportError("Onnx and protobuf need to be installed. "
                              + "Instructions to install - https://github.com/onnx/onnx")
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys.

        Scalar fields (``f``, ``i``, ``s``) map to the scalar value, with
        ``bytes`` decoded as UTF-8. Repeated fields (``floats``, ``ints``,
        ``strings``) map to a tuple. ``t`` and ``g`` map to the raw protobuf
        message.

        Raises
        ------
        NotImplementedError
            For attributes holding lists of tensors or graphs.
        ValueError
            If an attribute sets none of the handled fields. This includes
            repeated fields that are present but empty, since empty lists
            are skipped below.
        """
        attrs = {}
        for a in attr_proto:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
                    # 's' is a protobuf bytes field; decode it so operator
                    # handlers receive text rather than bytes.
                    if isinstance(attrs[a.name], bytes):
                        attrs[a.name] = attrs[a.name].decode(encoding='utf-8')
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t', 'g']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors', 'graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError("Field {} is not supported in mxnet.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs