/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "./utils.h"

namespace tvm {
namespace tir {

/******** Annotation ********/

/*!
 * \brief Build a new Block that is a copy of `block` with one annotation set.
 * \param block The block to copy; it is only read.
 * \param attr_key The annotation key, set via Map::Set on a copy of block->annotations.
 * \param attr_value The annotation value.
 * \return A new Block whose annotations contain `attr_key -> attr_value`.
 */
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
  Map<String, ObjectRef> annotations = block->annotations;
  annotations.Set(attr_key, attr_value);
  ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
  new_block->annotations = std::move(annotations);
  return Block(new_block);
}

/******** Buffer Related ********/

/*!
 * \brief Create a copy of `buffer` that lives in storage scope `scope`.
 *
 * The copy gets a brand-new data Var named `<old data name>_<scope>` whose type is a
 * PointerType with the original element type and the new scope. The buffer name is
 * suffixed with `_<scope>` as well.
 *
 * \param buffer The buffer to copy. Its data var must carry a PointerType annotation;
 *        TVM_TYPE_AS is expected to raise an error otherwise.
 * \param scope The storage scope for the new buffer.
 * \return The rescoped buffer copy.
 */
Buffer WithScope(const Buffer& buffer, const String& scope) {
  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
  const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode);
  // Build the fresh data Var directly. There is no need to first copy the old VarNode,
  // since only its name hint is reused.
  new_buffer->data =
      Var(buffer->data->name_hint + "_" + scope, PointerType(ptr_type->element_type, scope));
  new_buffer->name = buffer->name + "_" + scope;
  return Buffer(new_buffer);
}

/*!
 * \brief Redirect every BufferRegion that refers to `source` so it refers to `target`.
 * \param regions The regions to rewrite (taken by value; returned after rewriting).
 * \param source The buffer to be replaced.
 * \param target The replacement buffer.
 * \return `regions`, with each matching region replaced by a copy pointing at `target`
 *         (its `region` ranges unchanged); non-matching entries are returned as-is.
 */
Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
                                  const Buffer& target) {
  regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {
    if (region->buffer.same_as(source)) {
      ObjectPtr<BufferRegionNode> n = make_object<BufferRegionNode>(*region.get());
      n->buffer = target;
      return BufferRegion(n);
    }
    return region;
  });
  return regions;
}

/*!
 * \brief Redirect the source side of each MatchBufferRegion whose source buffer is `source`.
 * \param match_buffers The match-buffer regions to rewrite (taken by value; returned).
 * \param source The buffer to be replaced.
 * \param target The replacement buffer.
 * \return `match_buffers`, where each matching entry's `source` becomes
 *         BufferRegion(target, <same region>); other entries are returned as-is.
 */
Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers,
                                       const Buffer& source, const Buffer& target) {
  match_buffers.MutateByApply([&source,
                               &target](MatchBufferRegion match_buffer) -> MatchBufferRegion {
    if (match_buffer->source->buffer.same_as(source)) {
      ObjectPtr<MatchBufferRegionNode> n = make_object<MatchBufferRegionNode>(*match_buffer.get());
      n->source = BufferRegion(target, n->source->region);
      return MatchBufferRegion(n);
    }
    return match_buffer;
  });
  return match_buffers;
}

/******** Block Removal ********/

/*!
 * \brief Plan the removal of a leaf block from the IR without applying it.
 *
 * Starting at the leaf's parent, walk upwards through loops. The walk stops at the first
 * loop whose body is a SeqStmt with more than one statement, or at the first non-loop
 * ancestor (the scope-root block). The child subtree we came from is then removed from
 * that ancestor's SeqStmt body.
 *
 * \param self The schedule state (only `self->mod` is used, for error reporting).
 * \param leaf_block_sref The sref of the leaf block to remove.
 * \param src_stmt Output: the ancestor statement that should be replaced.
 * \param tgt_stmt Output: a copy of that ancestor with the leaf-containing child removed.
 * \throws OnlyLeafError (a ScheduleError) if the walk ends at a block whose body is not a
 *         SeqStmt, i.e. the leaf is the only thing in its scope.
 */
void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
                          Stmt* src_stmt, Stmt* tgt_stmt) {
  class OnlyLeafError : public ScheduleError {
   public:
    explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root)
        : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {}

    String FastErrorString() const final {
      return "ScheduleError: Cannot remove the only leaf in the scope";
    }

    String DetailRenderTemplate() const final {
      return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the "
             "scope will be empty.";
    }

    IRModule mod() const final { return mod_; }
    Array<ObjectRef> LocationsOfInterest() const final { return {leaf_block_, scope_root_}; }

    IRModule mod_;
    Block leaf_block_;
    Block scope_root_;
  };

  // Go upwards until find an ancestor with more than one child
  const StmtNode* last_stmt = leaf_block_sref->stmt;
  StmtSRefNode* sref = leaf_block_sref->parent;
  for (;; last_stmt = sref->stmt, sref = sref->parent) {
    // Checked before any dereference: the walk is expected to stop at the enclosing
    // scope-root block before it runs out of parents.
    ICHECK(sref != nullptr) << "InternalError: ran out of ancestors while searching upwards "
                               "from the leaf block for its removal point";
    if (const auto* loop = sref->StmtAs<ForNode>()) {
      if (const auto* seq = loop->body.as<SeqStmtNode>()) {
        if (seq->size() > 1) {
          break;
        }
      }
    } else {
      // Removal is not done beyond scope-level.
      // When encountering a block, i.e. the scope root, we simply stop
      break;
    }
  }
  // Case 1: stopped at the scope-root block; drop the subtree from the block's body.
  if (const auto* block = sref->StmtAs<BlockNode>()) {
    if (const auto* seq = block->body.as<SeqStmtNode>()) {
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
      n->body = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
      *src_stmt = GetRef<Block>(block);
      *tgt_stmt = Stmt(std::move(n));
      return;
    }
  }
  // Case 2: stopped at a loop whose body is a multi-statement SeqStmt.
  if (const auto* loop = sref->StmtAs<ForNode>()) {
    if (const auto* seq = loop->body.as<SeqStmtNode>()) {
      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
      n->body = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
      *src_stmt = GetRef<For>(loop);
      *tgt_stmt = Stmt(std::move(n));
      return;
    }
  }
  // Otherwise the leaf is the only content of its scope: report it.
  ICHECK(sref->stmt != nullptr);
  const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block, leaf_block_sref);
  const auto* scope_block = TVM_SREF_TO_BLOCK(scope_block, sref);
  throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
}

}  // namespace tir
}  // namespace tvm