/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file stmt_functor.cc */ #include #include #include #include #include #include "./functor_common.h" namespace tvm { namespace tir { void StmtVisitor::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitExpr(op->min); this->VisitExpr(op->extent); this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const WhileNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const AllocateNode* op) { VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitStmt(op->body); this->VisitExpr(op->condition); } void StmtVisitor::VisitStmt_(const StoreNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->index); this->VisitExpr(op->predicate); } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { this->VisitExpr(op->value); VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); this->VisitExpr(r->extent); }); this->VisitExpr(op->condition); this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); if (op->else_case.defined()) { this->VisitStmt(op->else_case); } } void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitExpr(op->condition); this->VisitExpr(op->message); this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->value); } void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); this->VisitExpr(r->extent); }); this->VisitStmt(op->body); this->VisitExpr(op->condition); } void StmtVisitor::VisitStmt_(const PrefetchNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); this->VisitExpr(r->extent); }); } void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); } void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } void StmtVisitor::VisitStmt_(const BlockNode* op) { auto fvisit_buffer_region = [this](const BufferRegion& s) { for (const auto& range : s->region) { this->VisitExpr(range->min); this->VisitExpr(range->extent); } }; VisitArray(op->iter_vars, [this](const IterVar& iter_var) { this->VisitExpr(iter_var->dom->min); this->VisitExpr(iter_var->dom->extent); }); VisitArray(op->reads, fvisit_buffer_region); VisitArray(op->writes, fvisit_buffer_region); VisitArray(op->match_buffers, [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { fvisit_buffer_region(match_buffer_region->source); }); if (op->init.defined()) { this->VisitStmt(op->init.value()); } this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->predicate); this->VisitStmt(op->block); } class StmtMutator::Internal { public: /*! * \brief Mutate array's element by fmutate function. * * \note Use extra care for copy on write setting. * * In particular, consider the following case of two reference chains: * - strongref0 -> loop0 -> loop1 -> loop2 * - strongref1 -> loop3 -> loop1 -> loop2 * * Think of the case of calling MutateArray on loop1->loop2(as const reference). * When both strongref0 and strongref1 exists, the context does not allow copy * on write, even though loop1 uniquely refers to loop2. * * \param self The pointer to the mutator. * \param arr Array to be mutated, const reference is used to allow copy on write * mutation in a recursive visitor. * \param fmutate The mutator function. * \return The mutated array, a new copy can be created. */ template static Array MutateArray(StmtMutator* self, const Array& arr, F fmutate) { if (self->allow_copy_on_write_ && arr.unique()) { // if we allow copy on write, we can directly // call the inplace mutate function. const_cast&>(arr).MutateByApply(fmutate); return arr; } else { bool allow_cow = false; Array copy = arr; std::swap(allow_cow, self->allow_copy_on_write_); copy.MutateByApply(fmutate); std::swap(allow_cow, self->allow_copy_on_write_); return copy; } } static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const IterVar& iter_var) { PrimExpr min = self->VisitExpr(iter_var->dom->min); PrimExpr extent = self->VisitExpr(iter_var->dom->extent); if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) { return iter_var; } else { return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type, iter_var->thread_tag); } }; return MutateArray(self, arr, fmutate); } static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; return MutateArray(self, arr, fmutate); } static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); }; return MutateArray(self, arr, fmutate); } static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const Range& r) { PrimExpr min = self->VisitExpr(r->min); PrimExpr extent = self->VisitExpr(r->extent); if (min.same_as(r->min) && extent.same_as(r->extent)) { return r; } else { return Range::FromMinExtent(min, extent); } }; return MutateArray(self, arr, fmutate); } static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const BufferRegion& buffer_region) { Array region = Mutate(self, buffer_region->region); if (region.same_as(buffer_region->region)) { return buffer_region; } else { return BufferRegion(buffer_region->buffer, region); } }; return MutateArray(self, arr, fmutate); } static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { Array region = Mutate(self, match_buffer_region->source->region); if (region.same_as(match_buffer_region->source->region)) { return match_buffer_region; } else { return MatchBufferRegion(match_buffer_region->buffer, BufferRegion(match_buffer_region->source->buffer, region)); } }; return MutateArray(self, arr, fmutate); } }; Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); n->body = std::move(body); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); n->body = std::move(body); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->min = std::move(min); n->extent = std::move(extent); n->body = std::move(body); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const WhileNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); n->body = std::move(body); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->extents = std::move(extents); n->body = std::move(body); n->condition = std::move(condition); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); Stmt else_case; if (op->else_case.defined()) { else_case = this->VisitStmt(op->else_case); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); n->then_case = std::move(then_case); n->else_case = std::move(else_case); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const StoreNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr index = this->VisitExpr(op->index); PrimExpr predicate = this->VisitExpr(op->predicate); if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); n->index = std::move(index); n->predicate = std::move(predicate); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { PrimExpr value = this->VisitExpr(op->value); Array indices = Internal::Mutate(this, op->indices); if (value.same_as(op->value) && indices.same_as(op->indices)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); n->indices = std::move(indices); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); PrimExpr condition = this->VisitExpr(op->condition); Stmt body = this->VisitStmt(op->body); if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->bounds = std::move(bounds); n->condition = std::move(condition); n->body = std::move(body); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) { Array indices = Internal::Mutate(this, op->indices); PrimExpr value = this->VisitExpr(op->value); if (indices.same_as(op->indices) && value.same_as(op->value)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->indices = std::move(indices); n->value = std::move(value); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->bounds = std::move(bounds); n->body = std::move(body); n->condition = std::move(condition); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { Region bounds = Internal::Mutate(this, op->bounds); if (bounds.same_as(op->bounds)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->bounds = std::move(bounds); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->seq = std::move(seq); return Stmt(n); } } // advanced visit function for seqstmt. Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate) { if (flatten_before_visit) { // Pass 1, check if we need to flatten. bool need_flatten = false; for (size_t i = 0; i < op->seq.size(); ++i) { Stmt tmp = (*op)[i]; if (tmp.as()) need_flatten = true; } flatten_before_visit = need_flatten; } // function to run the visit. auto frunvisit = [&](const SeqStmtNode* op) { Array seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate) : Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->seq = std::move(seq); return Stmt(n); } }; if (flatten_before_visit) { Array seq; SeqStmt::Flattener flattener(&seq); flattener(0, op->seq); // NOTE: If copy on write is allowed // the assignment to seq below will // destruct the original seq. // // Such destruction removes duplicated reference // count to children and still enables COW for // child Stmt. ObjectPtr n = CopyOnWrite(op); n->seq = std::move(seq); return frunvisit(n.operator->()); } else { return frunvisit(op); } } Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr message = this->VisitExpr(op->message); Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); n->message = std::move(message); n->body = std::move(body); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const BlockNode* op) { Array iter_vars = Internal::Mutate(this, op->iter_vars); Array reads = Internal::Mutate(this, op->reads); Array writes = Internal::Mutate(this, op->writes); Array match_buffers = Internal::Mutate(this, op->match_buffers); Optional init = NullOpt; if (op->init.defined()) { init = VisitStmt(op->init.value()); } Stmt body = VisitStmt(op->body); if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) && body.same_as(op->body) && init.same_as(op->init) && match_buffers.same_as(op->match_buffers)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_vars = std::move(iter_vars); n->reads = std::move(reads); n->writes = std::move(writes); n->body = std::move(body); n->init = std::move(init); n->match_buffers = std::move(match_buffers); return Stmt(n); } } Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { Array v = Internal::Mutate(this, op->iter_values); PrimExpr pred = this->VisitExpr(op->predicate); Stmt block = this->VisitStmt(op->block); if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) { return GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_values = std::move(v); n->predicate = std::move(pred); n->block = Downcast(block); return Stmt(n); } } // Implementations of IRTransform, PostOrderVisit and Substitute class IRApplyVisit : public StmtExprVisitor { public: explicit IRApplyVisit(std::function f) : f_(f) {} void VisitExpr(const PrimExpr& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); ExprVisitor::VisitExpr(node); f_(node); } void VisitStmt(const Stmt& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); StmtVisitor::VisitStmt(node); f_(node); } private: std::function f_; std::unordered_set visited_; }; void PostOrderVisit(const ObjectRef& node, std::function fvisit) { if (node.as()) { IRApplyVisit visitor(fvisit); visitor(Downcast(node)); } else { IRApplyVisit visitor(fvisit); visitor(Downcast(node)); } } class IRTransformer final : public StmtExprMutator { public: IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, const std::unordered_set& only_enable) : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} Stmt VisitStmt(const Stmt& stmt) final { return MutateInternal(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); }); } PrimExpr VisitExpr(const PrimExpr& expr) final { return MutateInternal(expr, [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); } private: // NOTE: redirect to parent's call // This is used to get around limitation of gcc-4.8 Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } template T MutateInternal(const T& node, F fmutate) { if (only_enable_.size() && !only_enable_.count(node->type_index())) { return fmutate(node); } if (f_preorder_ != nullptr) { T pre = f_preorder_(node); if (pre.defined()) return pre; } T new_node = fmutate(node); if (f_postorder_ != nullptr) { T post = f_postorder_(new_node); if (post.defined()) return post; } return new_node; } // The functions const runtime::PackedFunc& f_preorder_; const runtime::PackedFunc& f_postorder_; // type indices enabled. const std::unordered_set& only_enable_; }; Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, Optional> only_enable) { std::unordered_set only_type_index; if (only_enable.defined()) { for (auto s : only_enable.value()) { only_type_index.insert(Object::TypeKey2Index(s.c_str())); } } IRTransformer transform(f_preorder, f_postorder, only_type_index); return transform(std::move(ir_node)); } class IRSubstitute : public StmtExprMutator { public: explicit IRSubstitute(std::function(const Var&)> vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto ret = vmap_(var); if (ret.defined()) return ret.value(); return std::move(var); } PrimExpr VisitExpr_(const LoadNode* op) final { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { return Load(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); } else { return ret; } } Stmt VisitStmt_(const StoreNode* op) final { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { return Store(Downcast(mapped_var.value()), op->value, op->index, op->predicate); } else { return ret; } } Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); // remap var node in attr if (const auto* var_node = op->node.as()) { if (auto mapped_var = vmap_(GetRef(var_node))) { return AttrStmt(mapped_var, op->attr_key, op->value, op->body); } } return ret; } private: std::function(const Var&)> vmap_; }; Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { return IRSubstitute(vmap)(std::move(stmt)); } PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstitute(vmap)(std::move(expr)); } Array Substitute(const Array& region, const Map& vmap) { Array result; result.reserve(region.size()); for (const Range& range : region) { PrimExpr min = Substitute(range->min, vmap); PrimExpr extent = Substitute(range->extent, vmap); result.push_back(Range::FromMinExtent(std::move(min), std::move(extent))); } return result; } void PreOrderVisit(const ObjectRef& stmt_or_expr, const std::function& fvisit) { class PreOrderVisitor : public StmtExprVisitor { public: explicit PreOrderVisitor(const std::function& f) : f_(f) {} private: void VisitExpr(const PrimExpr& expr) final { const PrimExprNode* p_expr = expr.get(); if (visited_.count(p_expr) == 0) { visited_.insert(p_expr); if (f_(expr)) { ExprVisitor::VisitExpr(expr); } } } void VisitStmt(const Stmt& stmt) final { const StmtNode* p_stmt = stmt.get(); if (visited_.count(p_stmt) == 0) { visited_.insert(p_stmt); if (f_(stmt)) { StmtVisitor::VisitStmt(stmt); } } } const std::function& f_; std::unordered_set visited_; }; PreOrderVisitor visitor(fvisit); if (const auto* stmt = stmt_or_expr.as()) { visitor(GetRef(stmt)); } else if (const auto* expr = stmt_or_expr.as()) { visitor(GetRef(expr)); } else { LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: " << stmt_or_expr->GetTypeKey(); } } TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); TVM_REGISTER_GLOBAL("tir.Substitute") .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { if (node->IsInstance()) { return Substitute(Downcast(node), vmap); } else { return Substitute(Downcast(node), vmap); } }); } // namespace tir } // namespace tvm