/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file op_utils.h
 * \brief Common utilities used in operator construction.
 */
#ifndef TVM_TE_OPERATION_OP_UTILS_H_
#define TVM_TE_OPERATION_OP_UTILS_H_

// NOTE(review): the five #include directives below have no target, and several
// std:: container types further down have no template arguments (for example
// "std::vector >"). It looks like angle-bracketed text was lost from this copy
// of the file; restore it from the original before building. TODO confirm.
#include
#include
#include
#include
#include

#include "../../tir/transforms/arg_binder.h"
#include "../../tir/transforms/ir_utils.h"
#include "../schedule/message_passing.h"

namespace tvm {
namespace te {

// Make tir::MergeNest usable unqualified inside namespace te.
using tir::MergeNest;

/*!
 * \brief Build the loop nest for a stage.
 *
 * \param stage The stage to create a loop nest for.
 * \param dom_map The range of each iter var.
 * \param begin_iter_pos The beginning position in leaf_iter_vars from which loops are generated.
 * \param new_loop_var Whether to create a new loop variable.
 * \param skip_iter The set of iterations to skip.
 * \param p_value_map Pointer to the map holding the result value of each IterVar.
 * \param debug_keep_trivial_loop Whether to keep trivial loops with an extent of 1.
 * \return The loop nest built for the stage.
 */
std::vector > MakeLoopNest(const Stage& stage, const std::unordered_map& dom_map,
                           size_t begin_iter_pos, bool new_loop_var,
                           const std::unordered_set& skip_iter,
                           std::unordered_map* p_value_map, bool debug_keep_trivial_loop);

/*!
 * \brief Create a nest of if statements checking the predicates.
 *
 * \param predicates The predicates to be checked.
 * \return List of If nest that checks the predicates.
 */
std::vector MakeIfNest(const std::vector& predicates);

/*!
 * \brief Replace the tensor references (especially in Calls) in a statement using the replace map.
 * \param stmt The statement to be processed.
 * \param replace The replacement rule.
 * \return The statement with tensor references replaced.
 */
Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace);

/*!
 * \brief Replace the tensor references (especially in Calls) in a PrimExpr using the replace map.
 * \param expr The expression to be processed.
 * \param replace The replacement rule.
 * \return The expression with tensor references replaced.
 */
PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace);

/*!
 * \brief Substitute the variables of a statement using the value map.
 * \param stmt The statement to be processed.
 * \param value_map The value map.
 * \return Substituted result.
 */
Stmt Substitute(Stmt stmt, const std::unordered_map& value_map);

/*!
 * \brief Substitute the variables of a PrimExpr using the value map.
 * \param expr The expression to be processed.
 * \param value_map The value map.
 * \return Substituted result.
 */
PrimExpr Substitute(PrimExpr expr, const std::unordered_map& value_map);

/*!
 * \brief Convert a tir::ForKind to its corresponding IterVarType.
 * \param kind The ForKind to be converted.
 * \return The corresponding IterVarType.
 */
IterVarType ForKindToIterVarType(tir::ForKind kind);

/*!
 * \brief Convert an IterVarType to its corresponding tir::ForKind.
 * \param iter_type The IterVarType to be converted.
 * \return The corresponding ForKind.
 */
tir::ForKind IterVarTypeToForKind(IterVarType iter_type);

}  // namespace te
}  // namespace tvm

#endif  // TVM_TE_OPERATION_OP_UTILS_H_