/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2016 by Contributors * \file nnvm/op.h * \brief Operator information structor. */ #ifndef NNVM_OP_H_ #define NNVM_OP_H_ #include #include #include #include #include #include #include #include "base.h" #include "c_api.h" namespace nnvm { // forward declarations class Node; struct NodeAttrs; template class OpMap; class OpGroup; class OpRegistryEntry; using dmlc::ParamFieldInfo; /*! \brief constant to indicate it take any length of positional inputs */ static const uint32_t kVarg = std::numeric_limits::max(); /*! * \brief Operator structure. * * Besides the fields in the structure, * arbitary additional information can be associated with each op. * See function GetAttr for details. * * \code * // Example usage of Op * * // registeration of oeprators * // NOTE that the attr function can register any * // additional attributes to the operator * NNVM_REGISTER_OP(add) * .describe("add two inputs together") * .set_num_inputs(2) * .set_attr("OpKernel", AddKernel) * .include("ElementwiseOpAttr"); * * // can register attribute by group * // all the ops that include the group get the attribute. * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) * .set_attr("FInferShape", ElementwiseInferShape); * * NNVM_REGISTER_OP(sub) * .describe("substract one tensor from another") * .set_num_inputs(2); * * // Can call regster multiple times in different files * // to register different part of information * NNVM_REGISTER_OP(sub) * .set_attr("OpKernel", SubKernel); * .include("ElementwiseOpAttr"); * * // get operators from registry. * void my_function() { * const Op* add = Op::Get("add"); * const Op* sub = Op::Get("sub"); * // query basic information about each operator. * assert(op->name == "plus"); * assert(op->num_inputs == 2); * * // get additional registered information, * // Assume user registered a OpKernel type attribute as gpu_kernel on each operator. * const OpMap& kernel = Op::GetAttr("OpKernel"); * // we can get the kernel functions by using operator as key. * auto add_kernel = kernel[add]; * auto sub_kernel = kernel[sub]; * // subsequent code can make use of the queried kernel functions. * } * \endcode */ class NNVM_DLL Op { public: /*! \brief name of the operator */ std::string name; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. */ std::string description; /* \brief description of inputs and keyword arguments*/ std::vector arguments; /*! * \brief number of inputs to the operator, * -1 means it is variable length * When get_num_inputs is presented, * the number will be decided by get_num_inputs instead. * \sa get_num_inputs */ uint32_t num_inputs = 1; /*! * \brief number of outputs of the operator * When get_num_outputs is presented. * The number of outputs will be decided by * get_num_outputs function * \sa get_num_outputs */ uint32_t num_outputs = 1; /*! * \brief support level of the operator, * The lower the more priority it contains. * This is in analogies to BLAS levels. */ uint32_t support_level = 10; /*! * \brief get number of outputs given information about the node. * \param attrs The attribute of the node * \return number of outputs. */ std::function get_num_outputs = nullptr; /*! * \brief get number of inputs given information about the node. * \param attrs The attribute of the node * \return number of inputs */ std::function get_num_inputs = nullptr; /*! * \brief Attribute parser to parse the NodeAttrs information. * * This can help to get quick access to a parsed attribute * object * * \code * // Example usage of attr_parser. * * // Suppose we want to register operator sum. * // The parameters about sum operator * struct SumParam { * int axis; * }; * // The parser function * void SumAttrParser(NodeAttrs* attrs) { * // This will be invoked during node construction. * SumParam param; * // parse axis string to integer * param.axis = atoi(attrs->dict["axis"].c_str()); * // set the parsed parameter * attrs->parsed = std::move(param); * } * // The other function that can utilize the parsed result. * TShape SumInferShape(const NodeAttrs& attrs, * const std::vector& ishapes) { * // we can use the parsed version of param * // without repeatively parsing the parameter * const SumParam& param = nnvm::get(attrs.parsed); * } * \endcode */ std::function attr_parser = nullptr; // function fields. /*! * \brief setter function during registration * Set the description of operator * \param descr the description string. * \return reference to self. */ inline Op& describe(const std::string& descr); // NOLINT(*) /*! * \brief Add argument information to the function. * \param name Name of the argument. * \param type Type of the argument. * \param description Description of the argument. * \return reference to self. */ inline Op& add_argument(const std::string &name, const std::string &type, const std::string &description); /*! * \brief Append list if arguments to the end. * \param args Additional list of arguments. * \return reference to self. */ inline Op& add_arguments(const std::vector &args); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. * \return reference to self. */ inline Op& set_num_inputs(uint32_t n); // NOLINT(*) /*! * \brief Set the support level of op. * \param level The support level. * \return reference to self. */ inline Op& set_support_level(uint32_t level); // NOLINT(*) /*! * \brief Set the get_num_outputs function. * \param fn The function to be set. * \return reference to self. */ inline Op& set_num_inputs(std::function fn); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. * \return reference to self. */ inline Op& set_num_outputs(uint32_t n); // NOLINT(*) /*! * \brief Set the get_num_outputs function. * \param fn The function to be set. * \return reference to self. */ inline Op& set_num_outputs(std::function fn); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. * \return reference to self. */ inline Op& set_attr_parser(std::function fn); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. * \param value The value to be set. * \param plevel The priority level of this set, * an higher priority level attribute * will replace lower priority level attribute. * Must be bigger than 0. * * Cannot set with same plevel twice in the code. * * \tparam ValueType The type of the value to be set. */ template inline Op& set_attr(const std::string& attr_name, // NOLINT(*) const ValueType& value, int plevel = 10); /*! * \brief Add another alias to this operator. * The same Op can be queried with Op::Get(alias) * \param alias The alias of the operator. * \return reference to self. */ Op& add_alias(const std::string& alias); // NOLINT(*) /*! * \brief Include all the attributes from an registered op group. * \param group_name The name of the group. * \return reference to self. * * \sa NNVM_REGISTER_OP_GROUP */ Op& include(const std::string& group_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. * \param op_name Name of the operator. * \return Pointer to a Op, valid throughout program lifetime. */ static const Op* Get(const std::string& op_name); /*! * \brief Get additional registered attribute about operators. * If nothing has been registered, an empty OpMap will be returned. * \param attr_name The name of the attribute. * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. */ template static const OpMap& GetAttr(const std::string& attr_name); private: template friend class OpMap; friend class OpGroup; friend class dmlc::Registry; // Program internal unique index of operator. // Used to help index the program. uint32_t index_{0}; // internal constructor Op(); // get const reference to certain attribute static const any* GetAttrMap(const std::string& key); // update the attribute OpMap static void UpdateAttrMap(const std::string& key, std::function updater); // add a trigger based on tag matching on certain tag attribute // This will apply trigger on all the op such that // include the corresponding group. // The trigger will also be applied to all future registrations // that calls include static void AddGroupTrigger(const std::string& group_name, std::function trigger); }; /*! * \brief A map data structure that takes Op* as key * and returns ValueType * \tparam ValueType The type of the value stored in map. */ template class OpMap { public: /*! * \brief get the corresponding value element at op * \param op The key to the map * \return the const reference to the content value. */ inline const ValueType& operator[](const Op* op) const; /*! * \brief get the corresponding value element at op with default value. * \param op The key to the map * \param def_value The default value when the key does not exist. * \return the const reference to the content value. */ inline const ValueType& get(const Op* op, const ValueType& def_value) const; /*! * \brief Check if the map has op as key. * \param op The key to the map * \return 1 if op is contained in map, 0 otherwise. */ inline int count(const Op* op) const; /*! * \brief Check if the map has op as key. * \param op The key to the map * \return true if op is contained in map, false otherwise. */ inline bool contains(const Op* op) const; private: friend class Op; // internal attribute name std::string attr_name_; // internal data std::vector > data_; OpMap() = default; }; /*! * \brief auxiliary data structure used to * set attributes to a group of operators */ class OpGroup { public: /*! \brief the tag key to be matched */ std::string group_name; /*! * \brief Register additional attributes to operator group. * \param attr_name The name of the attribute. * \param value The value to be set. * \param plevel The priority level of this set, * an higher priority level attribute * will replace lower priority level attribute. * Must be bigger than 0. * * Cannot set with same plevel twice in the code. * * \tparam ValueType The type of the value to be set. */ template inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) const ValueType& value, int plevel = 1); }; // internal macros to make #define NNVM_REGISTER_VAR_DEF(OpName) \ static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName #define NNVM_REGISTER_GVAR_DEF(TagName) \ static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName /*! * \def NNVM_REGISTER_OP * \brief Register a new operator, or set attribute of the corresponding op. * * \param OpName The name of registry * * \code * * NNVM_REGISTER_OP(add) * .describe("add two inputs together") * .set_num_inputs(2) * .set_attr("gpu_kernel", AddKernel); * * \endcode */ #define NNVM_REGISTER_OP(OpName) \ DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) /*! * \def NNVM_REGISTER_OP_GROUP * \brief Register attribute to a group of operators. * These attributes will be registered to Op that include the group. * * \param GroupName The name of the group. * * \code * * NNVM_REGISTER_OP(add) * .include("ElementwiseOpAttr"); * * // register same attributes to all the ops that include the group * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) * .set_attr("FInferShape", ElementwiseInferShape); * * NNVM_REGISTER_OP(mul) * .include("ElementwiseOpAttr"); * * \endcode */ #define NNVM_REGISTER_OP_GROUP(GroupName) \ DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ ::nnvm::OpGroup {#GroupName} // implementations of template functions after this. // member function of Op template inline const OpMap& Op::GetAttr(const std::string& key) { const any* ref = GetAttrMap(key); if (ref == nullptr) { // update the attribute map of the key by creating new empty OpMap UpdateAttrMap(key, [key](any* pmap) { // use callback so it is in lockscope if (pmap->empty()) { OpMap pm; pm.attr_name_ = key; *pmap = std::move(pm); } }); ref = GetAttrMap(key); } return nnvm::get >(*ref); } template inline Op& Op::set_attr( // NOLINT(*) const std::string& attr_name, const ValueType& value, int plevel) { CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; // update the attribute map of the key by creating new empty if needed. UpdateAttrMap(attr_name, [this, attr_name, value, plevel](any* pmap) { // the callback is in lockscope so is threadsafe. if (pmap->empty()) { OpMap pm; pm.attr_name_ = attr_name; *pmap = std::move(pm); } CHECK(pmap->type() == typeid(OpMap)) << "Attribute " << attr_name << " of operator " << this->name << " is registered as inconsistent types" << " previously " << pmap->type().name() << " current " << typeid(OpMap).name(); std::vector >& vec = nnvm::get >(*pmap).data_; // resize the value type. if (vec.size() <= index_) { vec.resize(index_ + 1, std::make_pair(ValueType(), 0)); } std::pair& p = vec[index_]; CHECK(p.second != plevel) << "Attribute " << attr_name << " of operator " << this->name << " is already registered with same plevel=" << plevel; if (p.second < plevel) { vec[index_] = std::make_pair(value, plevel); } }); return *this; } inline Op& Op::describe(const std::string& descr) { // NOLINT(*) this->description = descr; return *this; } inline Op& Op::add_argument(const std::string &name, const std::string &type, const std::string &description) { arguments.push_back({name, type, type, description}); return *this; } inline Op& Op::add_arguments(const std::vector &args) { this->arguments.insert(arguments.end(), args.begin(), args.end()); return *this; } inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) this->num_inputs = n; return *this; } inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*) this->support_level = n; return *this; } inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) this->get_num_inputs = fn; return *this; } inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) this->num_outputs = n; return *this; } inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) this->get_num_outputs = fn; return *this; } inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) this->attr_parser = fn; return *this; } // member functions of OpMap template inline int OpMap::count(const Op* op) const { if (contains(op)) { return 1; } else { return 0; } } template inline bool OpMap::contains(const Op* op) const { if (op == nullptr) { return false; } const uint32_t idx = op->index_; return idx < data_.size() ? (data_[idx].second != 0) : false; } template inline const ValueType& OpMap::operator[](const Op* op) const { CHECK(op != nullptr); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second) << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; return data_[idx].first; } template inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { if (op == nullptr) return def_value; const uint32_t idx = op->index_; if (idx < data_.size() && data_[idx].second) { return data_[idx].first; } else { return def_value; } } template inline OpGroup& OpGroup::set_attr(const std::string& attr_name, const ValueType& value, int plevel) { auto trigger = [attr_name, value, plevel](Op* op) { op->set_attr(attr_name, value, plevel); }; Op::AddGroupTrigger(group_name, trigger); return *this; } } // namespace nnvm #endif // NNVM_OP_H_