/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * * \brief Lift specified AttrStmt scope to outer if * the body contains the same scope. * \file lift_attr_scope.cc */ #include #include #include #include "ir_utils.h" namespace tvm { namespace tir { // NOTE: this optimization can only be applied // to a few specified attr keys class AttrScopeLifter : public StmtMutator { public: explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {} Stmt Lift(Stmt stmt) { stmt = operator()(std::move(stmt)); if (attr_node_.defined()) { stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt); } return stmt; } // do not go beyond Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (attr_node_.defined()) { Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body); } else { return stmt; } } Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr_key_) { attr_node_ = op->node; attr_value_ = op->value; return op->body; } else { return StmtMutator::VisitStmt_(op); } } Stmt VisitStmt_(const SeqStmtNode* op) final { // remember the decorations. std::vector attr_node; std::vector attr_value; auto fmutate = [&](const Stmt& s) { attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); Stmt ret = this->VisitStmt(s); attr_node.push_back(attr_node_); attr_value.push_back(attr_value_); return ret; }; Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate); if (attr_node.size() == 0) return ret; op = ret.as(); ICHECK(op != nullptr); Array reorg; // check if all decorations are common. for (size_t begin = 0; begin < attr_node.size();) { size_t end = begin + 1; while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) && ValueSame(attr_value[end], attr_value[begin])) { ++end; } // covers everything // lift attr to parent. if (begin == 0 && end == attr_node.size()) { attr_node_ = attr_node[0]; attr_value_ = attr_value[0]; return ret; } // construct subsegments. Array seq; for (size_t i = begin; i < end; ++i) { seq.push_back(op->seq[i]); } Stmt stmt = SeqStmt::Flatten(seq); if (attr_node[begin].defined()) { stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt); } reorg.push_back(stmt); begin = end; } attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); return SeqStmt::Flatten(reorg); } Stmt VisitStmt_(const IfThenElseNode* op) final { if (!op->else_case.defined()) { return StmtMutator::VisitStmt_(op); } Stmt then_case = this->VisitStmt(op->then_case); ObjectRef first_node; PrimExpr first_value; std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); Stmt else_case = this->VisitStmt(op->else_case); if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && first_value.defined() && attr_node_.same_as(first_node) && ValueSame(attr_value_, first_value)) { if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { return IfThenElse(op->condition, then_case, else_case); } } else { if (first_node.defined()) { then_case = AttrStmt(first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); } if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { return IfThenElse(op->condition, then_case, else_case); } } } Stmt VisitStmt_(const WhileNode* op) final { // TODO(masahi): Do we need a special handling for While nodes? LOG(FATAL) << "WhileNode not supported in LiftAttrScope."; return Stmt(); } private: // value comparison that also compares content of int constant static bool ValueSame(const PrimExpr& a, const PrimExpr& b) { if (a.same_as(b)) return true; if (!a.defined() || !b.defined()) return false; if (a->type_index() != b->type_index()) return false; if (a.dtype() != b.dtype()) return false; if (const IntImmNode* op = a.as()) { return op->value == b.as()->value; } return false; } std::string attr_key_; ObjectRef attr_node_; PrimExpr attr_value_; }; Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } namespace transform { Pass LiftAttrScope(String attr_key) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {}); } TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope); } // namespace transform } // namespace tir } // namespace tvm