/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2016 by Contributors * \file nnvm/symbolic.h * \brief Symbolic graph construction API * * This API is optional, but useful to allow user * to construct NNVM Graph easily, and quickly create * front-end host languages. */ #ifndef NNVM_SYMBOLIC_H_ #define NNVM_SYMBOLIC_H_ #include #include #include #include #include #include "base.h" #include "node.h" namespace nnvm { /*! * \brief Symbol is help class used to represent the operator node in Graph. * * Symbol acts as an interface for building graphs from different components * like Variable, Functor and Group. Symbol is also exported to python front-end * (while Graph is not) to enable quick test and deployment. Conceptually, * symbol is the final operation of a graph and thus including all the information * required (the graph) to evaluate its output value. */ class NNVM_DLL Symbol { public: /*! \brief option passed to ListAttr */ enum ListAttrOption { /*! \brief recursively list all attributes */ kRecursive = 0, /*! \brief only list attributes in current node */ kShallow = 1 }; /*! \brief option passed to ListInputNames */ enum ListInputOption { /*! \brief list all the arguments */ kAll = 0, /*! \brief list only read only arguments */ kReadOnlyArgs = 1, /*! * \brief List auxiliary states that can be mutated by the graph. * This excludes the ReadOnly arguments */ kAuxiliaryStates = 2 }; /*! \brief output entries contained in the symbol */ std::vector outputs; /*! * \brief Copy the symbol. * \return A deep copy of this symbol. */ Symbol Copy() const; /*! * \brief Print the symbol info to output stream. * \param os The output stream to print to. */ void Print(std::ostream &os) const; // NOLINT(*) /*! * \brief Get the index-th element from the returned tuple. * \param index Index of multi output. * \return The symbol corresponds to the indexed element. */ Symbol operator[] (size_t index) const; /*! * \brief List the input variable nodes. * * The order of the returned list is the same as the order of the input list to `operator()`. * * \param option The options to list the arguments. * \return The arguments list of this symbol, they can be either named or unnamed (empty string). * \sa ListInputOption */ std::vector ListInputs(ListInputOption option) const; /*! * \brief List the input names. * * The order of the returned list is the same as the order of the input list to `operator()`. * * \param option The options to list the arguments. * \return The arguments list of this symbol, they can be either named or unnamed (empty string). * \sa ListInputOption */ std::vector ListInputNames(ListInputOption option) const; /*! * \brief List the names of outputs for this symbol. * * For normal operators, it is usually symbol node name + "_output". * * \return get the descriptions of outputs for this symbol. */ std::vector ListOutputNames() const; /*! * \brief Compose the symbol with arguments, this changes the current symbol. * The kwargs passed in can be in-complete, * * The rest of the symbols will remain the same name. * * \param args Positional arguments. * \param kwargs Keyword arguments for the symbol. * \param name Name of returned symbol. */ void Compose(const array_view& args, const std::unordered_map& kwargs, const std::string& name); /*! * \brief Apply the symbol as a function, compose with arguments * * This is equivalent to Copy then Compose. * * \param args Positional arguments for the symbol. * \param kwargs Keyword arguments for the symbol. * \param name Name of returned symbol. * \return A new Symbol which is the composition of current symbol with its arguments. */ Symbol operator () (const array_view& args, const std::unordered_map& kwargs, const std::string& name) const; /*! * \brief Add control flow dependencies to the operators in symbols. * * For grouped symbol, an error will be raised. This mutates current symbolic Node. * * \param src The symbols to depend on. */ void AddControlDeps(const Symbol& src); /* * \brief Get all the internal nodes of the symbol. * \return symbol A new symbol whose output contains all the outputs of the symbols * including input variables and intermediate outputs. */ Symbol GetInternals() const; /* * \brief Get the direct inputs of the head node(s) of this symbol. * \return symbol A new symbol whose output contains all the inputs of the head * node(s). */ Symbol GetChildren() const; /*! * \brief Set additional attributes to current node. * * This only works for symbol with outputs from single operators. * For grouped symbol, an error will be raised. * * This function mutates the node's symbol and is not recommended. * * \param attrs The attributes to set. */ void SetAttrs(const std::vector >& attrs); /*! * \brief Get attributes from the symbol. * * This only works for symbol with outputs from single operators. * For grouped symbol, an error will be raised. * * \param key Key of the attribute. When key == "name", it returns the name attirbute. * \param out The output value of the attribute. * \return true If the attribute exists, false if the attribute does not exist. */ bool GetAttr(const std::string& key, std::string* out) const; /*! * \brief Get attribute dictionary from the symbol. * * For grouped symbol, an error will be raised. * * \param option If recursive flag is set, the attributes of all children are retrieved. * The name of symbol will be pre-pended to each key. * \return The created attribute. */ std::unordered_map ListAttrs(ListAttrOption option) const; /*! * \brief Get attribute dictionary from the symbol and all children. * * For grouped symbol, an error will be raised. * * \return The created attribute in format . */ std::vector > ListAttrsRecursive() const; /*! * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. * \param op The operator. * \param attrs The additional attributes. * \return Symbol that can be used to call compose further. */ static Symbol CreateFunctor(const Op* op, std::unordered_map attrs); /*! * \brief Create symbolic functor(AtomicSymbol) by given node attributes. * \param attrs pre-initialized Node attributes. * \return Symbol that can be used to call compose further. */ static Symbol CreateFunctor(const NodeAttrs& attrs); /*! * \brief Create symbol node representing variable. * \param name Name of the variable. * \return The symbol. */ static Symbol CreateVariable(const std::string& name); /*! * \brief Create equivalence of symbol by grouping the symbols together. * \param symbols A list of symbols to be grouped. * \return The grouped symbol. */ static Symbol CreateGroup(const std::vector& symbols); }; } // namespace nnvm #endif // NNVM_SYMBOLIC_H_