/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * * \file combine_parallel_op.h * \brief Abstract class to combine parallel ops and their successive element-wise ops. */ #ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ #include #include #include #include #include #include #include #include #include #include #include "./expr_subst.h" #include "pattern_utils.h" namespace tvm { namespace relay { using Branch = std::vector; using Group = std::vector; using FIsSupportedOp = std::function; using FAreCompatibleOps = std::function; using ExprSubstMap = std::unordered_map; /* * Class to find parallel branches starting with op that are * grouped if they are able to be combined. They are eligible to * be combined if they have the same input data. * Op can be followed by zero or more elemwise or broadcast ops, * which are included in the group. * Intermediate nodes have exactly one successor. It is possible that branches meet at a point, * which should be handled in ParallelOpCombiner. * * data * / \ * op op * | | * elem-wise elem-wise * | | */ class BranchGroupFinder : private ExprVisitor { public: /* * \brief Constructor * \param op The op that indicates the start of each group * \param fis_supported_op function that returns true if op * is supported for combining * \param fare_compatible_ops function that returns true if * two ops are compatible for combining */ BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); /* * \brief Finds all groups that can be combined. * \param expr Relay expression that represents function * to look at for groups to be combined * \return Vector of groups which can be combined. */ std::vector Find(const Expr& expr); private: /* \brief Cache the op for finding parallel branches */ const Op& cached_op_; /* \brief function to return true if op is eligible to be combined, * false otherwise */ FIsSupportedOp fis_supported_op_; /* \brief function to return true if two parallel ops are eligible * to be combined, false otherwise */ FAreCompatibleOps fare_compatible_ops_; /* \brief ops that are on the first (logically, leftmost) branch * of parallel ops and are eligible to be combined */ std::unordered_set op_roots_; /* \brief map of Expr to CallNodes that follow it */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> children_map_; /* * \brief Creates new branch from op and its children that have * elementwise or broadcast patterns * \return New branch */ Branch CreateBranch(const CallNode* op); /* * \brief Expression visitor function */ void VisitExpr_(const CallNode* n) final; }; /* * Abstract class to find and combine parallel ops and the elementwise ops that follow. */ class ParallelOpCombiner { public: /*! \brief virtual destructor */ virtual ~ParallelOpCombiner() {} /* * \brief Constructor. * \param op_name name of op to combine * \param min_num_branches min number of parallel branches beginning with op * to start combining */ explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); /* * \brief Combines ops and following elementwise or broadcast ops * \param expr function to modify * \return new function with combined ops */ Expr Combine(const Expr& expr); protected: /* * \brief Checks if node is supported to be combined * \param n node in question * \return True if the op represented by n is supported to be the root of a branch * to be combined. False otherwise. */ virtual bool IsSupportedOp(const CallNode* n) = 0; /* * \brief Checks if two ops can be combined * \param a node a * \param b node b * \return True if a and b can be combined. False otherwise. */ virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; /* * \brief Makes combined op from parallel ops in branches. This usually involves * concatenating or stacking inputs, then creating a new call. * \param branches branches that are to be combined * \return new call with branches combined. */ virtual Call MakeCombinedOp(const Group& branches) = 0; /* * \brief Checks if argument of op following combined ops are able to be combined * \param a node a * \param b node b * \param index index of argument in question * \return True if argument of a and b and index can be combined */ virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; /* * \brief Create combined call from ops that follow the initial combined op at the depth-th level. * This usually involves concatenating or stacking inputs, then creating a new call. * Only called if IsArgCompatbile returns true for each arg. * \param data combined op * \param branches branches of parallel ops to be combined * \param depth depth at which to combine ops * \param parent_index index of arg that corresponds to original input that was shared among * all combined ops * \return new combined call */ virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) = 0; /* * \brief Updates map of expr to substitute with combined expr. This usually involves * slicing or splitting data. * \param data combined op * \param branches branches of parallel ops to be combined * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) = 0; private: /* \brief Cache the op to be combined */ const Op& cached_op_; /* \brief minimum number of parallel branches to combine */ uint64_t min_num_branches_; /* \brief map of Expr to Expr to substitute it with after running pass */ ExprSubstMap subst_map_; /* * \brief Combine parallel branches and updates subst_map_ with Exprs * to be substituted * \param branches branches to be combined */ void CombineBranches(const Group& branches); /* * \brief Combine parallel branches and updates subst_map_ with Exprs * to be substituted * \param branches parallel branches to potentially be combined * \param depth depth at which to look at op * \param parent_index index of arg that corresponds to original input that was shared among * all combined ops * \return true if parallel ops at depth can be combined, false otherwise */ bool CheckLevel(const Group& branches, size_t depth, size_t parent_index); }; } // namespace relay } // namespace tvm #endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_