/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2016 by Contributors * \file nnvm/pass.h * \brief Pass that can be applied to a graph. */ #ifndef NNVM_PASS_H_ #define NNVM_PASS_H_ #include #include #include "base.h" #include "graph.h" namespace nnvm { /*! * \brief A PassFunction is an "Operator on Graph". * It takes a source graph and return a graph that may or may * not be the same as the input one. * * A pass function can either change the graph structure (thus, * generating a new Graph), or add new attributes to the graph. * * \param src The graph to be transformed. * \return The generated graph. */ typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on the input graph. * \param src The graph to be transformed. * \param passes A list of pass names to be applied. * \return The transformed graph */ Graph ApplyPasses(Graph src, const std::vector& passes); /*! * \brief Apply one pass to the graph. * \param src The graph to be transformed. * \param pass The name of pass to be applied. * \return The transformed graph. */ inline Graph ApplyPass(Graph src, const std::string& pass) { return ApplyPasses(src, {pass}); } /*! * \brief Registry entry for pass functions. */ struct PassFunctionReg : public dmlc::FunctionRegEntryBase { /*! * \brief Whether the pass will change graph structure * If this is false, the pass will only change attributes. */ bool change_graph{false}; /*! \brief dependencies on operator attributes */ std::vector op_attr_dependency; /*! \brief dependencies on attributes in the graph */ std::vector graph_attr_dependency; /*! \brief generated targets of graph attributes */ std::vector graph_attr_targets; /*! * \brief Set whether this pass will change graph structure. * \param v If true, the pass will change graph structure. * \return Reference to self. */ PassFunctionReg& set_change_graph(bool v) { // NOLINT(*) change_graph = v; return *this; } /*! * \brief Declare that this pass will generate the given graph attribute name * once it is applied on the graph. * \param attr_name Name of the graph attribute. * \return Reference to self. */ PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*) graph_attr_targets.push_back(attr_name); return *this; } /*! * \brief Declare this pass requires the given operator attribute to be * available before being applied on the graph. * \param attr_name Name of the attribute. * \return Reference to self. */ PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*) op_attr_dependency.push_back(attr_name); return *this; } /*! * \brief Declare this pass requires the given graph attribute to be * available before being applied on the graph. * \param attr_name Name of the attribute. * \return Reference to self. */ PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*) graph_attr_dependency.push_back(attr_name); return *this; } }; /*! * \def NNVM_REGISTER_PASS * \brief Macro to register pass fuctions. * * \code * // example of registering a shape inference pass * NNVM_REGISTER_PASS(InferShape) * .describe("Shape Inference function, generate graph attributes") * .provide_graph_attr("data_shape") * .depend_graph_attr("indexed_graph") * .depend_op_attr("infer_shape") * .set_body([](const Graph& g) { * // shape inference logic * }); * \endcode */ #define NNVM_REGISTER_PASS(name) \ DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) } // namespace nnvm #endif // NNVM_PASS_H_