/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "./utils.h"

namespace tvm {
namespace tir {

/******** Utility functions ********/

/*!
 * \brief Shorthand for an unordered map keyed by an object reference.
 * \tparam K The key type
 * \tparam V The mapped type
 * \note NOTE(review): the hash/equality functors were reconstructed; they must
 *       match the map members declared on BlockScopeNode -- confirm against the header.
 */
template <class K, class V>
using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;

/*!
 * \brief Add a dependency relation.
 * \param self The block scope node whose `src2deps` and `dst2deps` tables are updated
 * \param src The source of the dependency
 * \param dst The destination of the dependency
 * \param kind Type of the dependency
 * \note This method is effectively NOP on self-loops
 */
void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
  if (!src.same_as(dst)) {
    // The same Dependency reference is indexed from both directions.
    Dependency dep(src, dst, kind);
    self->src2deps[src].push_back(dep);
    self->dst2deps[dst].push_back(dep);
  }
}

/******** Constructors ********/

/*!
 * \brief Construct an sref by storing the three arguments on a freshly allocated node.
 * \param stmt The statement pointer recorded in the node's `stmt` field
 * \param parent The parent node pointer recorded in the node's `parent` field
 * \param seq_index The value recorded in the node's `seq_index` field
 */
StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
  ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
  n->stmt = stmt;
  n->parent = parent;
  n->seq_index = seq_index;
  data_ = std::move(n);
}

/*!
 * \brief Return the inline-mark sref.
 * \return A reference to a function-local static sref built with a null statement,
 *         a null parent and seq_index -1. It is a different object from the one
 *         returned by RootMark.
 */
StmtSRef StmtSRef::InlineMark() {
  static StmtSRef result(nullptr, nullptr, -1);
  return result;
}

/*!
 * \brief Return the root-mark sref.
 * \return A reference to a function-local static sref built with a null statement,
 *         a null parent and seq_index -1. It is a different object from the one
 *         returned by InlineMark.
 */
StmtSRef StmtSRef::RootMark() {
  static StmtSRef result(nullptr, nullptr, -1);
  return result;
}

/*!
 * \brief Construct a dependency edge.
 * \param src The source of the dependency
 * \param dst The destination of the dependency
 * \param kind Type of the dependency
 */
Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
  ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
  node->src = std::move(src);
  node->dst = std::move(dst);
  node->kind = kind;
  data_ = std::move(node);
}

BlockScope::BlockScope() { data_ = make_object(); } BlockScope::BlockScope(const Array& child_block_srefs) { ObjectPtr n = make_object(); SMap> buffer_readers; SMap>& buffer_writers = n->buffer_writers; for (const StmtSRef& child_block_sref : child_block_srefs) { const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref); // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer for (const BufferRegion& region : child_block->reads) { buffer_readers[region->buffer].push_back(child_block_sref); } for (const BufferRegion& region : child_block->writes) { buffer_writers[region->buffer].push_back(child_block_sref); } // Step 2. Update RAW dependency for (const BufferRegion& region : child_block->reads) { auto it = buffer_writers.find(region->buffer); if (it != buffer_writers.end()) { for (const StmtSRef& from : it->second) { AddDependency(n.get(), from, child_block_sref, DepKind::kRAW); } } } // Step 3. Update WAW dependency for (const BufferRegion& region : child_block->writes) { auto it = buffer_writers.find(region->buffer); if (it != buffer_writers.end()) { for (const StmtSRef& from : it->second) { AddDependency(n.get(), from, child_block_sref, DepKind::kWAW); } } } // Step 4. 
Update WAR dependency for (const BufferRegion& region : child_block->writes) { auto it = buffer_readers.find(region->buffer); if (it != buffer_readers.end()) { for (const StmtSRef& from : it->second) { AddDependency(n.get(), from, child_block_sref, DepKind::kWAR); } } } } data_ = std::move(n); } /******** Dependency ********/ Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { auto iter = this->src2deps.find(block_sref); if (iter != this->src2deps.end()) { return iter->second; } else { return {}; } } Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { auto iter = this->dst2deps.find(block_sref); if (iter != this->dst2deps.end()) { return iter->second; } else { return {}; } } /******** FFI ********/ TVM_REGISTER_NODE_TYPE(StmtSRefNode); TVM_REGISTER_NODE_TYPE(DependencyNode); TVM_REGISTER_NODE_TYPE(BlockScopeNode); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt") .set_body_typed([](StmtSRef sref) -> Optional { return GetRef>(sref->stmt); }); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent") .set_body_typed([](StmtSRef sref) -> Optional { return GetRef>(sref->parent); }); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") // .set_body_typed(StmtSRef::RootMark); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") // .set_body_typed(StmtSRef::InlineMark); TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc") .set_body_method(&BlockScopeNode::GetDepsBySrc); TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst") .set_body_method(&BlockScopeNode::GetDepsByDst); } // namespace tir } // namespace tvm