/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2016 by Contributors * \file nnvm/node.h * \brief Graph node data structure. */ #ifndef NNVM_NODE_H_ #define NNVM_NODE_H_ #include #include #include #include #include #include "base.h" #include "op.h" #include "c_api.h" namespace nnvm { // Forward declare node. class Node; class Symbol; /*! * \brief we always used NodePtr for a reference pointer * to the node, so this alias can be changed in case. * * By default, NodePtr is a std::shared_ptr of node */ using NodePtr = std::shared_ptr; /*! \brief an entry that represents output data from a node */ struct NodeEntry { NodeEntry(NodePtr node, uint32_t index, uint32_t version): node(std::move(node)), index(index), version(version) {} explicit NodeEntry(NodePtr node): node(std::move(node)), index(), version() {} /** * MXNet assumes that a node with a null ptr doesn't have a gradient attached. Don't change this * constructor. */ NodeEntry(): node(nullptr), index(), version() {} /*! \brief the source node of this data */ NodePtr node; /*! \brief index of output from the source. */ uint32_t index; /*! * \brief version of input Variable. * This field can only be nonzero when this->node is a Variable node. * version is increased by one each time a Variable get composed to a mutation Op. * This information can be helpful to decide order of operations when sequence of mutation happens. */ uint32_t version; }; /*! * \brief This lets you use a NodeEntry as a key in a unordered_map of the form * unordered_map */ struct NodeEntryHash { size_t operator()(const NodeEntry& e) const { return std::hash()(e.node.get()) ^ (std::hash()(e.index) << 1 >> 1) ^ (std::hash()(e.version) << 1); } }; /*! * \brief This lets you use a NodeEntry as a key in a unordered_map of the form * unordered_map */ struct NodeEntryEqual { size_t operator()(const NodeEntry& a, const NodeEntry& b) const { return (a.node.get() == b.node.get()) && (a.index == b.index) && (a.version == b.version); } }; /*! use NodeEntry as key in unordered_map */ template using NodeEntryMap = std::unordered_map; /*! * \brief The attributes of the current operation node. * Usually are additional parameters like axis, */ struct NodeAttrs { /*! * \brief The operator this node uses. * For place holder variable, op == nullptr. */ const Op *op{nullptr}; /*! \brief name of the node */ std::string name; /*! \brief The dictionary representation of attributes */ std::unordered_map dict; /*! * \brief A parsed version of attributes, * This is generated if OpProperty.attr_parser is registered. * The object can be used to quickly access attributes. */ any parsed; /*! * \brief Some operators take graphs as input. These operators include * control flow operators and high-order functions. * These graphs don't change when the operators are invoked for different * mini-batches. In this sense, the subgraphs are kind of similar to * the parameters and show be kept as node attributes. * * Users need to make sure the subgraphs are disjoint with the main graph. * If a graph shares nodes with subgraphs, loading the graph from LoadJSON * may generate a graph that has a different structure from the original graph * (some of the nodes are duplicated). If nodes are shared between two graphs, * shared nodes might be executed multiple times, which can be a problem for * stateful operators. */ std::vector > subgraphs; }; /*! * \brief Node represents an operation in a computation graph. */ class NNVM_DLL Node { public: Node() = default; Node(const Op* op, const std::string& name) { this->attrs.op = op; this->attrs.name = name; } /*! \brief The attributes in the node. */ NodeAttrs attrs; /*! \brief inputs to this node */ std::vector inputs; /*! * \brief Optional control flow dependencies * Gives operation must be performed before this operation. */ std::vector control_deps; /*! \brief additional fields for this node */ any info; /*! \brief destructor of node */ ~Node(); /*! \return operator in this node */ inline const Op* op() const; /*! * \brief return whether node is placeholder variable. * This is equivalent to op == nullptr * \return whether node is placeholder input variable */ inline bool is_variable() const; /*! \return number of outputs from this node */ inline uint32_t num_outputs() const; /*! \return number of inputs from this node */ inline uint32_t num_inputs() const; /*! * \brief create a new empty shared_ptr of Node. * \return a created empty node. */ template static NodePtr Create(Args&&... args) { return std::make_shared(std::forward(args)...); } }; /*! * \brief Quick utilities make node. * \param op_name The name of operator * \param node_name The name of the node * \param inputs The input entries * \param attrs The attributes * \return The created node entry. */ inline NodeEntry MakeNode( const char* op_name, std::string node_name, std::vector inputs, std::unordered_map attrs = std::unordered_map()) { NodePtr p = Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); p->attrs.dict = attrs; if (p->attrs.op->attr_parser) { p->attrs.op->attr_parser(&(p->attrs)); } p->inputs = std::move(inputs); return NodeEntry(p, 0, 0); } // implementation of functions. inline const Op* Node::op() const { return this->attrs.op; } inline bool Node::is_variable() const { return this->op() == nullptr; } inline uint32_t Node::num_outputs() const { if (is_variable()) return 1; if (this->op()->get_num_outputs == nullptr) { return this->op()->num_outputs; } else { return this->op()->get_num_outputs(this->attrs); } } inline uint32_t Node::num_inputs() const { if (is_variable()) return 1; if (this->op()->get_num_inputs == nullptr) { return this->op()->num_inputs; } else { return this->op()->get_num_inputs(this->attrs); } } } // namespace nnvm #endif // NNVM_NODE_H_