/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

using support::NDIntSet;

/******** Error Classes ********/

/*!
 * \brief An error raised when not all required blocks are under the given loop.
 * \tparam is_consumer Indicates whether the required blocks are consumers or producers
 */
template <bool is_consumer>
class NotAllRequiredBlocksAreVisitedError : public ScheduleError {
 public:
  explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited,
                                               const Array<StmtSRef>& required)
      : mod_(mod), num_not_visited_(num_not_visited) {
    required_.reserve(required.size());
    for (const StmtSRef& block_sref : required) {
      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
      required_.push_back(GetRef<Block>(block));
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: Not all required blocks are under the loop scope";
  }

  String DetailRenderTemplate() const final {
    String relation = is_consumer ? "consumer(s)" : "producer(s)";
    std::ostringstream os;
    os << "The primitive requires all the " << relation
       << " of the given block to be present under the target loop. However, there are "
       << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the "
       << relation << ":";
    for (int i = 0, n = required_.size(); i < n; ++i) {
      os << "{" << i << "}";
    }
    return os.str();
  }

  IRModule mod() const final { return mod_; }

  Array<ObjectRef> LocationsOfInterest() const final {
    return {required_.begin(), required_.end()};
  }

 private:
  IRModule mod_;
  int num_not_visited_;
  Array<Block> required_;
};
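// Illustrative note (hypothetical scenario, not from the test suite): with
// is_consumer == true, this error fires when, e.g., compute-at is asked to move
// a producer block under a loop that covers only one of its two consumers; the
// consumer outside the loop could no longer read the regions produced inside
// it, so the primitive refuses to proceed.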
/*!
 * \brief An error raised when the given block is not in the same block scope as the given loop,
 * or the given loop is the ancestor of the given block.
 */
class NotInSameScopeError : public ScheduleError {
 public:
  static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, const StmtSRef& scope_root_sref,
                                     arith::Analyzer* analyzer) {
    for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) {
      if (const ForNode* loop = p->StmtAs<ForNode>()) {
        analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
      } else if (p != scope_root_sref.get()) {
        throw NotInSameScopeError(self->mod, block_sref, loop_sref);
      } else {
        break;
      }
    }
    for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) {
      if (p == loop_sref.get()) {
        throw NotInSameScopeError(self->mod, block_sref, loop_sref);
      }
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: Expected the block and loop to be under the same block scope, and loop "
           "not to be the ancestor of block";
  }
  String DetailRenderTemplate() const final {
    return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, "
           "and loop not to be the ancestor of block";
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; }

 private:
  explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref)
      : mod_(mod),
        block_(GetRef<Block>(block_sref->StmtAs<BlockNode>())),
        loop_(GetRef<For>(loop_sref->StmtAs<ForNode>())) {}

  IRModule mod_;
  Block block_;
  For loop_;
};

/******** Helper Functions/Classes ********/

/*!
 * \brief Find a point where the block can be inserted under the loop
 * \tparam require_all_producers_visited Requires all producer blocks to be present under the loop
 * \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop
 * \param self The schedule state
 * \param subtrees The subtrees under the loop, among which the insertion points are sought
 * \param producer_srefs The producer blocks
 * \param consumer_srefs The consumer blocks
 * \param block2realize A cache that maps a block to its realize
 * \return The last position the new block can be inserted into such that the
 * producer-consumer relationship is still satisfied
 * \throws ScheduleError if no such insertion point is found
 */
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(
    const ScheduleState& self, const Array<Stmt>& subtrees,
    const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs,
    std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
  ProducerConsumerSplit split =
      ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
  // Step 1. Check if all the producers are visited in the subtrees, if required to
  if (require_all_producers_visited) {
    int num_producers = producer_srefs.size();
    if (split.n_producers_visited < num_producers) {
      throw NotAllRequiredBlocksAreVisitedError<false>(
          self->mod, num_producers - split.n_producers_visited, producer_srefs);
    }
  }
  // Step 2. Check if all the consumers are visited in the subtrees, if required to
  if (require_all_consumers_visited) {
    int num_consumers = consumer_srefs.size();
    if (split.n_consumers_visited < num_consumers) {
      throw NotAllRequiredBlocksAreVisitedError<true>(
          self->mod, num_consumers - split.n_consumers_visited, consumer_srefs);
    }
  }
  // Step 3. Check that there is at least one valid position to insert into.
  // The valid indices are: (last_producer_position, first_consumer_position]
  ICHECK(split.last_producer_position < split.first_consumer_position);
  // Step 4. Return the last valid insertion point
  return split.first_consumer_position;
}
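// Worked example for FindInsertionPoint above (hypothetical subtree layout, for
// illustration only): suppose the subtrees under `loop` are [P0, P1, X, C0],
// where P0/P1 produce for the block being moved, C0 consumes from it, and X is
// unrelated. Then split.last_producer_position == 1 and
// split.first_consumer_position == 3, so every index in (1, 3] is a valid
// insertion point, and the function returns 3, i.e. right before C0.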
/*!
 * \brief A helper to reconstruct the block scope when the given block is moved under the given
 * loop, and the given block's induced loop nest is regenerated to satisfy the required region.
 */
class ScopeReconstructor : private StmtMutator {
 public:
  explicit ScopeReconstructor(Block scope_root, Block block, For loop)
      : scope_root_(scope_root), block_(block), loop_(loop) {}

  using StmtMutator::operator();

  /*!
   * \brief Create the loop nest on top of the block, induced by the given block var's domain
   * \param insert_position The position among the subtrees where the block and its induced loop
   * nest is inserted
   * \param iter_doms The domain of each block var
   * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
   */
  void MakeNewLoop(int insert_position, std::vector<Range> iter_doms, bool preserve_unit_loops) {
    int n_iters = iter_doms.size();
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Array<PrimExpr> iter_values;
    loop_vars.reserve(n_iters);
    loop_extents.reserve(n_iters);
    iter_values.reserve(n_iters);
    for (int i = 0; i < n_iters; ++i) {
      const Range& iter_dom = iter_doms[i];
      if (preserve_unit_loops || !is_one(iter_dom->extent)) {
        Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32));
        loop_vars.push_back(var);
        loop_extents.push_back(iter_dom->extent);
        iter_values.push_back(iter_dom->min + var);
      } else {
        iter_values.push_back(iter_dom->min);
      }
    }
    this->new_block_realize_ =
        BlockRealize(std::move(iter_values), const_true(), std::move(block_));
    Stmt new_subtree = this->new_block_realize_;
    for (int i = static_cast<int>(loop_vars.size()) - 1; i >= 0; --i) {
      const Var& loop_var = loop_vars[i];
      const PrimExpr& loop_extent = loop_extents[i];
      new_subtree = For(/*loop_var=*/loop_var,
                        /*min=*/Integer(0),
                        /*extent=*/loop_extent,
                        /*kind=*/ForKind::kSerial,
                        /*body=*/std::move(new_subtree));
    }
    Array<Stmt> subtrees = AsArray(loop_->body);
    subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree));
    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop_.get());
    new_loop->body = SeqStmt(std::move(subtrees));
    this->new_loop_ = For(std::move(new_loop));
  }

 private:
  Stmt VisitStmt_(const BlockNode* block) final {
    if (block != scope_root_.get()) {
      return GetRef<Block>(block);
    }
    if (block == rm_src_stmt_.get()) {
      block = TVM_TYPE_AS(block, rm_tgt_stmt_, BlockNode);
    }
    return StmtMutator::VisitStmt_(block);
  }

  Stmt VisitStmt_(const ForNode* loop) final {
    if (loop == rm_src_stmt_.get()) {
      loop = TVM_TYPE_AS(loop, rm_tgt_stmt_, ForNode);
    }
    if (loop == loop_.get()) {
      return new_loop_;
    }
    return StmtMutator::VisitStmt_(loop);
  }

 public:
  /*! \brief The root block of the block scope */
  Block scope_root_;
  /*! \brief The given block to be moved */
  Block block_;
  /*! \brief The given loop the block and its loop nest to be put under */
  For loop_;
  /*! \brief The new loop to replace the original loop */
  For new_loop_{nullptr};
  /*! \brief The new block realize to the moved block */
  BlockRealize new_block_realize_{nullptr};
  /*! \brief The plan to remove the given block by replacing this loop/block in the AST */
  Stmt rm_src_stmt_{nullptr};
  /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */
  Stmt rm_tgt_stmt_{nullptr};
};
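// Sketch of what MakeNewLoop generates (hypothetical domains, for illustration
// only): with iter_doms = [Range(i0 * 4, /*extent=*/4), Range(0, /*extent=*/1)]
// and preserve_unit_loops == false, the regenerated subtree is, in TIR-like
// pseudocode:
//   for ax0 in range(4):
//     block(iter_values = [i0 * 4 + ax0, 0])
// i.e. the unit-extent domain is folded into the binding instead of producing a
// trivial loop; with preserve_unit_loops == true, an extra `for ax1 in range(1)`
// would wrap the block as well.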
/*!
 * \brief Calculate a list of accessed buffer regions under a path of loops
 * \tparam relax_storage_scope Whether to relax beyond the path according to the storage and
 * execution scope
 * \param binding The block binding, used to substitute the block vars in the buffer regions
 * \param buffer_regions The buffer regions to be calculated
 * \param relax_path_low_inclusive The lowest point in the loop path, inclusive
 * \param relax_path_high_exclusive The highest point in the loop path, exclusive
 * \param relaxed Where the calculation result is stored
 */
template <bool relax_storage_scope>
void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
                        const Array<BufferRegion>& buffer_regions,
                        const StmtSRef& relax_path_low_inclusive,
                        const StmtSRef& relax_path_high_exclusive,
                        std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* relaxed) {
  runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""};
  // We cache the variable domains
  runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal;
  Optional<Map<Var, arith::IntSet>> var_dom = NullOpt;
  // Enumerate every buffer region
  for (const BufferRegion& buffer_region : buffer_regions) {
    const Buffer& buffer = buffer_region->buffer;
    const Array<Range>& region = buffer_region->region;
    // Skip the buffer regions we are not interested in
    auto it = relaxed->find(buffer.get());
    if (it == relaxed->end()) {
      continue;
    }
    std::vector<NDIntSet>& relaxed_regions = it->second;
    // Check and update the cached `var_dom`
    runtime::StorageScope scope =
        relax_storage_scope ? runtime::StorageScope::Create(buffer.scope()) : global_scope;
    runtime::StorageRank rank = scope.rank;
    if (rank != previous_rank || !var_dom.defined()) {
      previous_rank = rank;
      var_dom = AsIntSet(LoopDomainOfSRefTreePath(
          /*low_inclusive=*/relax_path_low_inclusive,
          /*high_exclusive=*/relax_path_high_exclusive,
          /*extra_relax_scope=*/scope));
    }
    // Relax the region
    Array<arith::IntSet> relaxed_region =
        arith::EvalSet(Substitute(region, binding), var_dom.value());
    relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()});
  }
}
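// Illustrative example (hypothetical loop nest): suppose a required block reads
// B[vi, vj] with binding {vi -> i, vj -> j} inside
//   for i:              # `relax_path_high_exclusive` sits at/above this loop
//     for j in 32:      # `relax_path_low_inclusive`
// Relaxing over the sref path [j-loop, i-loop) turns j into the interval
// [0, 32) while i stays symbolic, so the region recorded for B is
// [i, i + 1) x [0, 32).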
/*!
 * \brief Calculate the iteration domain needed for the provided integer set to fully cover the
 * required domain
 * \param provided The integer set provided by a single execution instance of the block
 * \param required The domain required to be covered
 * \param iter_doms The result iteration domains to be updated
 * \param analyzer The arithmetic analyzer
 */
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
                          std::unordered_map<const VarNode*, std::vector<arith::IntSet>>* iter_doms,
                          arith::Analyzer* analyzer) {
  PrimExpr provided_min = analyzer->Simplify(provided.min());
  PrimExpr provided_extent = analyzer->Simplify(provided.max() - provided_min + 1);
  PrimExpr required_min = analyzer->Simplify(required.min());
  PrimExpr required_extent = analyzer->Simplify(required.max() - required_min + 1);
  PrimExpr dom_min{nullptr}, dom_extent{nullptr};
  Var dom_var{ObjectPtr<Object>{nullptr}};
  arith::PVar<Var> p_v;
  arith::PVar<PrimExpr> p_e;
  if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
    PrimExpr e = p_e.Eval();
    dom_var = p_v.Eval();
    dom_min = floordiv(required_min, e);
    dom_extent = analyzer->Simplify((required_extent + e - 1) / e);
  } else if (analyzer->CanProveEqual(provided_extent, 1) && p_v.Match(provided_min)) {
    dom_var = p_v.Eval();
    dom_min = required_min;
    dom_extent = required_extent;
  } else {
    ICHECK(false) << "ValueError: BufferRegion pattern match failed";
  }
  auto it = iter_doms->find(dom_var.get());
  if (it != iter_doms->end()) {
    std::vector<arith::IntSet>& doms = it->second;
    doms.push_back(arith::IntSet::FromMinExtent(dom_min, dom_extent));
  } else {
    ICHECK(analyzer->CanProveEqual(provided_min, required_min));
    ICHECK(analyzer->CanProveEqual(provided_extent, required_extent));
  }
}
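// Worked example of the pattern match above (hypothetical numbers): if one
// block instance provides [v * 4, v * 4 + 4) in some dimension and the required
// region is [0, 10), the first branch matches provided_min == v * 4 with
// e == 4, giving dom_min = floordiv(0, 4) = 0 and
// dom_extent = (10 + 4 - 1) / 4 = 3; iterating v over [0, 3) then provides
// [0, 12), which covers the requirement.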
/*!
 * \brief Calculate the domain of block vars to cover the required region
 * \param iter_vars The list of block vars to cover the required region
 * \param provided_regions The region provided by one iteration instance of the block vars
 * \param required_regions The region required to be covered
 * \param analyzer The arithmetic analyzer
 * \return A list of iteration domains corresponding to the given list of block vars
 */
std::vector<Range> CalculateBlockVarDomain(
    const Array<IterVar>& iter_vars,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions,
    arith::Analyzer* analyzer) {
  int n_iters = iter_vars.size();
  // Step 1. Construct the mapping from block vars to their iteration domain (initialized to empty)
  std::unordered_map<const VarNode*, std::vector<arith::IntSet>> iter_doms;
  iter_doms.reserve(n_iters);
  for (const IterVar& iter_var : iter_vars) {
    iter_doms[iter_var->var.get()] = {};
  }
  // Step 2. For each buffer, update the domain according to the provided and required regions
  for (const auto& kv : provided_regions) {
    const BufferNode* buffer = kv.first;
    const std::vector<NDIntSet>& many_provided_regions = kv.second;
    // Calculate `provided_region` and `required_region`
    auto it = required_regions.find(buffer);
    if (it == required_regions.end() || it->second.empty()) {
      continue;
    }
    NDIntSet required_region = support::NDIntSetUnion(it->second);
    NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
    ICHECK_EQ(provided_region.size(), buffer->shape.size());
    ICHECK_EQ(required_region.size(), buffer->shape.size());
    // For each dimension, update the iteration domain
    int ndim = buffer->shape.size();
    for (int i = 0; i < ndim; ++i) {
      arith::IntSet provided = provided_region[i];
      arith::IntSet required = required_region[i];
      required = arith::Intersect(
          {std::move(required), arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i])});
      UpdateBlockVarDomain(provided, required, &iter_doms, analyzer);
    }
  }
  // Step 3. Union the iter var domains, keep them in the same order as the block vars, and return
  std::vector<Range> result;
  result.reserve(n_iters);
  for (const IterVar& iter_var : iter_vars) {
    const std::vector<arith::IntSet>& doms = iter_doms.at(iter_var->var.get());
    arith::IntSet dom = arith::IntSet::FromRange(iter_var->dom);
    if (!doms.empty()) {
      dom = arith::Intersect({std::move(dom), arith::Union(doms)});
    }
    PrimExpr min = analyzer->Simplify(dom.min());
    PrimExpr extent = analyzer->Simplify(dom.max() - min + 1);
    result.push_back(Range::FromMinExtent(min, extent));
  }
  return result;
}
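// Illustrative example (hypothetical domains): if buffer A demands block var v
// to range over [0, 3) and buffer B demands v over [2, 6), the loop above
// unions them into [0, 6); intersecting with the var's original domain, say
// [0, 128), yields the final Range(0, /*extent=*/6) for that block var.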
/*!
 * \brief Calculate the region provided by a single execution instance of the given block,
 * as well as the required buffer regions relaxed to the given loop
 * \tparam is_compute_at Indicates if the operation is compute-at or reverse-compute-at
 * \param block The given block that provides buffer regions
 * \param loop_sref The given loop under which the block is going to be moved
 * \param block2realize Maps a block to its corresponding BlockRealize
 * \param producer_srefs The producers of the given block
 * \param consumer_srefs The consumers of the given block
 * \param provided_regions The calculated regions provided by the block
 * \param required_regions The calculated regions required by its consumers (in compute-at) or
 * producers (in reverse-compute-at)
 */
template <bool is_compute_at>
void CalculateProvidedRequiredRegions(
    const BlockNode* block, const StmtSRef& loop_sref,
    std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize,
    Array<StmtSRef> producer_srefs, Array<StmtSRef> consumer_srefs,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* provided_regions,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* required_regions) {
  // Step 1. Calculate the region provided by a single execution instance of `block`
  const Array<BufferRegion>& provided_buffers = is_compute_at ? block->writes : block->reads;
  provided_regions->reserve(provided_buffers.size());
  required_regions->reserve(provided_buffers.size());
  for (const BufferRegion& provided_buffer_region : provided_buffers) {
    const BufferNode* buffer = provided_buffer_region->buffer.get();
    const Array<Range>& region = provided_buffer_region->region;
    (*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region));
    (*required_regions)[buffer].clear();
  }
  // Step 2. Calculate the region required by dependent blocks under `loop`
  for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) {
    const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block, required_block_sref);
    ICHECK(block2realize.count(required_block));
    RelaxBufferRegions</*relax_storage_scope=*/is_compute_at>(
        /*binding=*/GetBindings(GetRef<BlockRealize>(block2realize.at(required_block))),
        /*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes,
        /*relax_path_low_inclusive=*/GetRef<StmtSRef>(required_block_sref->parent),
        /*relax_path_high_exclusive=*/loop_sref,
        /*relaxed=*/required_regions);
  }
}

/******** Main Implementation ********/

template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, bool preserve_unit_loops,
                                     arith::Analyzer* analyzer, bool check_only = false) {
  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
  // Step 1. Bunch of checks
  // Check condition 1) and 2): stage pipeline and subtree compact dataflow
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
                                          /*require_stage_pipeline=*/true,
                                          /*require_subtree_compact_dataflow=*/true);
  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  BlockScope scope = self->GetBlockScope(scope_root_sref);
  Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
  Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
  // Check condition 3): `block` and `loop` are under the same scope,
  // and `loop` is not the ancestor of `block`
  NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
                                              analyzer);
  // Check condition 4): `block` is not an output block
  if (is_compute_at) {
    CheckNotOutputBlock(self, block_sref, scope_root_sref);
  }
  // Step 2. Plan for the removal of `block`
  ScopeReconstructor reconstructor(scope_root, GetRef<Block>(block), GetRef<For>(loop));
  LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_);
  // Step 3. Find the insertion point under `loop`
  // Check condition 5): all the required blocks are under the given loop
  std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
  block2realize.reserve(self->block_info.size());
  int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
      /*self=*/self,
      /*subtrees=*/AsArray(loop->body),
      /*producer_srefs=*/producer_srefs,
      /*consumer_srefs=*/consumer_srefs,
      /*block2realize=*/&block2realize);
  // Step 4. Calculate the region provided by a single execution instance of `block`,
  // as well as the region required by dependent blocks under `loop`.
  // Here is the definition of `provide` and `require`:
  // - In compute-at, `provide` means `produce`, and `require` means `consume`
  // - In reverse-compute-at, `provide` means `consume`, and `require` means `produce`
  std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions;
  std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions;
  CalculateProvidedRequiredRegions<is_compute_at>(
      /*block=*/block, /*loop_sref=*/loop_sref,
      /*block2realize=*/std::move(block2realize),
      /*producer_srefs=*/std::move(producer_srefs),
      /*consumer_srefs=*/std::move(consumer_srefs),
      /*provided_regions=*/&provided_regions,
      /*required_regions=*/&required_regions);
  // Step 5. Calculate the iteration domain for each block var
  std::vector<Range> iter_doms =
      CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
                              /*provided_regions=*/std::move(provided_regions),
                              /*required_regions=*/std::move(required_regions),
                              /*analyzer=*/analyzer);
  // Step 6. Create the new scope according to the iteration domain
  reconstructor.MakeNewLoop(/*insert_position=*/insert_position,
                            /*iter_doms=*/std::move(iter_doms),
                            /*preserve_unit_loops=*/preserve_unit_loops);
  Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
  // Step 7. Do the actual replacement
  if (check_only) {
    return;
  }
  self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
  // Step 8. Update the cached flags
  BlockInfo& block_info = self->block_info[block_sref];
  block_info.affine_binding = IsAffineBinding(
      /*realize=*/reconstructor.new_block_realize_,
      /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
      /*analyzer=*/analyzer);
}
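// End-to-end sketch (TIR-like pseudocode, hypothetical workload) of what
// ComputeAtOrReverseComputeAtImpl<true> achieves. Before:
//   for i in range(128):
//     B[i] = A[i] * 2.0     # the block to be moved
//   for j in range(128):
//     C[j] = B[j] + 1.0     # its only consumer, under the target loop `j`
// After compute-at:
//   for j in range(128):
//     B[j] = A[j] * 2.0     # the required region per j is exactly B[j], so the
//     C[j] = B[j] + 1.0     # regenerated loop nest degenerates to a unit loop
// (with preserve_unit_loops == false the unit loop is folded away entirely).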
void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
               bool preserve_unit_loops) {
  arith::Analyzer analyzer;
  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
                                        &analyzer);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                      bool preserve_unit_loops) {
  arith::Analyzer analyzer;
  ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
                                         &analyzer);
}

bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                  bool preserve_unit_loops) {
  arith::Analyzer analyzer;
  try {
    ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
                                          &analyzer, /*check_only=*/true);
  } catch (const tvm::runtime::Error& e) {
    return false;
  }
  return true;
}

bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                         const StmtSRef& loop_sref, bool preserve_unit_loops) {
  arith::Analyzer analyzer;
  try {
    ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
                                           &analyzer, /*check_only=*/true);
  } catch (const tvm::runtime::Error& e) {
    return false;
  }
  return true;
}

/******** InstructionKind Registration ********/

struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {
  static constexpr const char* kName = "ComputeAt";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
                                      Bool preserve_unit_loops) {
    return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
                                 Bool preserve_unit_loops) {
    PythonAPICall py("compute_at");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits> {
  static constexpr const char* kName = "ReverseComputeAt";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
                                      Bool preserve_unit_loops) {
    return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
                                 Bool preserve_unit_loops) {
    PythonAPICall py("reverse_compute_at");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(ComputeAtTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeAtTraits);

}  // namespace tir
}  // namespace tvm