/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \brief Helper utilities to implement compute_op. * \file compute_op.h */ #ifndef TVM_TE_OPERATION_COMPUTE_OP_H_ #define TVM_TE_OPERATION_COMPUTE_OP_H_ #include #include #include #include namespace tvm { namespace te { // loop nest structure for general compute // This the loop nest structured used in compute. // Does not include the loop body. struct ComputeLoopNest { // The common number of loops between init and main size_t num_common_loop; // predicates for the initialize loop std::vector init_predicates; // Initialization nest involved. std::vector > init_nest; // Value map for the init code std::unordered_map init_vmap; // Predicates for the main update loop std::vector main_predicates; // The general loop nest std::vector > main_nest; // Value map for the IterVar. std::unordered_map main_vmap; /*! * \brief constructor to build ComputeOpNest * \param self The pointer to compute op. * \param stage The scxhedule stage. * \param dom_map The domain map. * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The constructed loop nest */ static ComputeLoopNest Create(const BaseComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop); }; /*! * \brief Build body of compute for cross thread reduction pattern. * \param self The pointer to ComputeOpNode * \param stage The schedule stage. * \param dom_map The domain map. * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop); /*! * \brief Build body of compute for tensorization. * \param self The pointer to ComputeOpNode * \param stage The schedule stage. * \param dom_map The domain map. * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop); /*! * \brief Transform the update part when there is no init func in tensorizing * \param stage The stage for tensorizing. * \param dom_map The range of each iter var. * \param n The loop nest structured used in compute. * \param body The body func in tensorize intrin * \param update The update func in tensorize intrin * \return Transformed result. */ Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, const ComputeLoopNest& n, Stmt body, Stmt update); } // namespace te } // namespace tvm #endif // TVM_TE_OPERATION_COMPUTE_OP_H_