/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/analysis/buffer_access_lca_detector.cc
 * \brief Detect the lowest common ancestor (LCA) of buffer accesses.
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
 * \brief Detect the lowest common ancestor (LCA) position of Buffer access.
 * \note Only consider BlockNode and ForNode to be the LCA nodes.
 */
class LCADetector : public StmtExprVisitor {
 public:
  /*!
   * \brief Compute, for every buffer accessed in \p func, the innermost For/Block that
   *        encloses all of its accesses.
   * \param func The function to analyze.
   * \return Map from each accessed buffer to its LCA statement. NullOpt means the LCA is
   *         the function root, i.e. no single For/Block encloses every access.
   */
  static Map<Buffer, Optional<Stmt>> Detect(const PrimFunc& func) {
    LCADetector detector;
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
    }

    // The function root must be an explicit scope on the ancestor stack. nullptr cannot
    // represent it, because nullptr already means "buffer not seen yet" in buffer_lca_.
    // The root has no stmt and no parent, and sits at depth 0.
    detector.ancestor_scopes_.push_back(detector.arena_.make<ScopeInfo>(nullptr, nullptr, 0));

    detector(func->body);

    // Prepare the return. The root scope carries a null stmt, which the nullable
    // GetRef<Optional<Stmt>> turns into NullOpt.
    Map<Buffer, Optional<Stmt>> buffer_lca;
    for (const auto& kv : detector.buffer_lca_) {
      Buffer buffer = GetRef<Buffer>(kv.first);
      Optional<Stmt> stmt = kv.second ? GetRef<Optional<Stmt>>(kv.second->stmt) : NullOpt;
      buffer_lca.Set(buffer, stmt);
    }
    return buffer_lca;
  }

 private:
  /*!
   * \brief The AST node information for querying LCA.
   * \note Only BlockNode and ForNode are considered, since they are the only statements whose
   *       body can be a SeqStmt (the LCA of buffer access) in TensorIR.
   */
  struct ScopeInfo {
    // The enclosing scope; nullptr only for the function root.
    const ScopeInfo* parent_scope_info;
    // The For/Block node that opens this scope; nullptr for the function root.
    const StmtNode* stmt;
    // The scope depth in the AST (the root is at depth 0).
    int depth;
    ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth)
        : parent_scope_info(parent_info), stmt(stmt), depth(depth) {}
  };

  /*! \brief Open a new scope for the loop, visit its body, then close the scope. */
  void VisitStmt_(const ForNode* op) final {
    int n = static_cast<int>(ancestor_scopes_.size());
    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
    ancestor_scopes_.push_back(current_scope);
    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
  }

  /*!
   * \brief Register the block's allocated buffers, open a new scope for the block, and
   *        record accesses made through its match_buffers before visiting the body.
   */
  void VisitStmt_(const BlockNode* op) final {
    int n = static_cast<int>(ancestor_scopes_.size());
    for (const Buffer& buf : op->alloc_buffers) {
      buffer_var_map_.emplace(buf->data.get(), buf.get());
    }

    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

    ancestor_scopes_.push_back(current_scope);
    // Update match_buffers: matching a region counts as an access to the source buffer at
    // this block, while the matched buffer itself is excluded from the result.
    for (const MatchBufferRegion& match_buffer : op->match_buffers) {
      UpdateBufferLCA(match_buffer->source->buffer.get());
      match_buffers_.insert(match_buffer->buffer.get());
    }

    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    UpdateBufferLCA(op->buffer.get());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    UpdateBufferLCA(op->buffer.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const BufferRealizeNode* op) final {
    buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  // Works for Load/Store and opaque access.
  void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }

  // Explicitly visit the buffer data var in Load and Store nodes.
  void VisitExpr_(const LoadNode* op) final {
    ExprVisitor::VisitExpr_(op);
    VisitBufferVar(op->buffer_var.get());
  }

  void VisitStmt_(const StoreNode* op) final {
    StmtVisitor::VisitStmt_(op);
    VisitBufferVar(op->buffer_var.get());
  }

  /*! \brief If \p op is the data var of a known buffer, record an access to that buffer. */
  void VisitBufferVar(const VarNode* op) {
    auto it = buffer_var_map_.find(op);
    if (it != buffer_var_map_.end()) {
      UpdateBufferLCA(it->second);
    }
  }

  /*! \brief Merge the current innermost scope into the recorded LCA of \p buffer. */
  void UpdateBufferLCA(const BufferNode* buffer) {
    if (match_buffers_.find(buffer) == match_buffers_.end()) {
      // Ignore buffers created by a block's match_buffer.
      const ScopeInfo*& lca = buffer_lca_[buffer];
      lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
    }
  }

  /*!
   * \brief Lowest common ancestor of two scopes.
   * \note nullptr means "no access recorded yet" and acts as the identity element. The
   *       function root is the only scope whose parent is nullptr, so reaching it while
   *       walking upward means the root is the LCA.
   */
  static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
    if (lhs == nullptr) return rhs;
    if (rhs == nullptr) return lhs;
    while (lhs->parent_scope_info != nullptr &&  //
           rhs->parent_scope_info != nullptr &&  //
           lhs != rhs) {
      if (lhs->depth == rhs->depth) {
        lhs = lhs->parent_scope_info;
        rhs = rhs->parent_scope_info;
      } else if (lhs->depth < rhs->depth) {
        rhs = rhs->parent_scope_info;
      } else {
        lhs = lhs->parent_scope_info;
      }
    }
    if (lhs->parent_scope_info == nullptr) {
      return lhs;
    }
    if (rhs->parent_scope_info == nullptr) {
      return rhs;
    }
    ICHECK(lhs == rhs);
    return lhs;
  }

  /*!
   * \brief The ancestor scope stack (root, then enclosing Fors and Blocks). Detect() pushes
   *        the explicit root scope before the traversal starts.
   */
  std::vector<const ScopeInfo*> ancestor_scopes_ = {};
  /*! \brief The map from Buffer to its LCA scope; nullptr entries mean "not seen yet". */
  std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
  /*! \brief The map from Buffer data to the Buffer. */
  std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
  /*! \brief The match buffers inside blocks. */
  std::unordered_set<const BufferNode*> match_buffers_ = {};
  /*! \brief Internal arena owning every ScopeInfo, including the root. */
  support::Arena arena_;
};

Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func) {
  return LCADetector::Detect(func);
}

TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
}  // namespace tir
}  // namespace tvm