/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file coproc_sync.cc */ #include #include #include #include #include #include #include #include "ir_utils.h" #include "storage_access.h" namespace tvm { namespace tir { // Visitor to find touched set by co-processor scope. class CoProcTouchedBuffer : public StmtExprVisitor { public: void VisitExpr_(const LoadNode* op) final { if (in_scope_) { touched_[op->buffer_var.get()].coproc = true; } else { touched_[op->buffer_var.get()].normal = true; } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const StoreNode* op) final { if (in_scope_) { touched_[op->buffer_var.get()].coproc = true; } else { touched_[op->buffer_var.get()].normal = true; } StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { const VarNode* buffer = op->args[1].as(); if (in_scope_) { touched_[buffer].coproc = true; } else { touched_[buffer].normal = true; } } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::coproc_scope && !in_scope_) { in_scope_ = true; IterVar iv = Downcast(op->node); coproc_.insert(iv); StmtExprVisitor::VisitStmt_(op); in_scope_ = false; } else { StmtExprVisitor::VisitStmt_(op); } } // Touch Entry struct TouchEntry { bool normal{false}; bool coproc{false}; }; std::unordered_map touched_; std::unordered_set coproc_; private: bool in_scope_{false}; }; // Synchronization planning with co-processor. class CoProcSyncPlanner : public StorageAccessVisitor { public: explicit CoProcSyncPlanner(const std::unordered_set& touched, const std::string& coproc_name) : touched_(touched), coproc_name_(coproc_name) {} void Plan(const Stmt& stmt) { this->VisitStmt(stmt); PlanSync(scope_.back(), nullptr, true); if (sync_.size() == 0) { sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync"); } } // Write synchronization to be inserted before or after stmt. std::unordered_map > sync_; protected: bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync std::vector Summarize(std::vector seq, const ForNode* loop) final { return PlanSync(seq, loop, false); } private: // Plan write synchronization if write is not coherent std::vector PlanSync(std::vector seq, const ForNode* loop, bool force_sync_at_end) { // detect write barriers // access by the co-processor. std::vector co_access; bool contain_sync = false; auto find_conflict = [&](const AccessEntry& acc) { for (const AccessEntry& x : co_access) { if (x.buffer.same_as(acc.buffer) && ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) { return true; } } return false; }; for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry& s = seq[i]; bool sync_write = false; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { sync_write = true; break; } if (acc.type == kSync) { co_access.clear(); contain_sync = true; } } if (sync_write) { ICHECK_NE(i, 0U); sync_[seq[i - 1].stmt] = GetSync(co_access); co_access.clear(); contain_sync = true; } for (const AccessEntry& acc : s.access) { if (acc.threads.size() != 0) { co_access.push_back(acc); } } } bool sync_at_end = force_sync_at_end; if (loop != nullptr && !sync_at_end) { // loop carray dependency for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry& s = seq[i]; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { sync_at_end = true; break; } } if (sync_.count(s.stmt) || sync_at_end) break; } } if (sync_at_end && co_access.size() != 0) { ICHECK_NE(seq.size(), 0); contain_sync = true; sync_[seq.back().stmt] = GetSync(co_access); co_access.clear(); } if (contain_sync) { AccessEntry e; e.type = kSync; co_access.insert(co_access.begin(), e); } return co_access; } // Add write Synchronization std::vector GetSync(const std::vector& co_access) { // Does not consider memory coherence, need runtime. ICHECK_NE(co_access.size(), 0U); ICHECK_EQ(co_access[0].threads.size(), 1U); return GetSync(coproc_name_ + ".coproc_sync"); } std::vector GetSync(std::string sync_name) { return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))}; } const std::unordered_set& touched_; std::string coproc_name_; }; // Detect memory barriers when coproc read/write memory class CoProcBarrierDetector : public StorageAccessVisitor { public: explicit CoProcBarrierDetector(const std::unordered_set& touched, const std::string& coproc_name) : touched_(touched) { read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier"; write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier"; } void PlanReadBarrier(const Stmt& stmt) { read_barrier_ = true; this->VisitStmt(stmt); PlanReadBarrier(scope_.back(), nullptr); } void PlanWriteBarrier(const Stmt& stmt) { read_barrier_ = false; this->VisitStmt(stmt); PlanWriteBarrier(scope_.back(), nullptr); } std::unordered_map > barrier_before_; std::unordered_map > barrier_after_; protected: bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync std::vector Summarize(std::vector seq, const ForNode* loop) final { if (read_barrier_) { return PlanReadBarrier(seq, loop); } else { return PlanWriteBarrier(seq, loop); } } private: // Plan write barrier at Read after write point. std::vector PlanWriteBarrier(std::vector seq, const ForNode* loop) { std::vector read_seq; std::unordered_map > write_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { auto it = write_set.find(acc.buffer.get()); if (it != write_set.end()) { ICHECK_NE(i, 0U); barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second)); write_set.erase(it); } }; for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry& s = seq[i]; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && acc.type == kRead) { fupdate(i, acc); read_seq.push_back(acc); } } for (const AccessEntry& acc : s.access) { if (acc.threads.size() != 0 && acc.type == kWrite) { write_set[acc.buffer.get()].push_back(acc); } } } // loop carry if (loop != nullptr) { for (const AccessEntry& acc : read_seq) { fupdate(seq.size(), acc); } } for (const auto& kv : write_set) { read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end()); } return read_seq; } std::vector PlanReadBarrier(std::vector seq, const ForNode* loop) { std::vector write_seq; std::unordered_map > read_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { auto it = read_set.find(acc.buffer.get()); if (it != read_set.end()) { ICHECK_NE(i, seq.size()); barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second)); read_set.erase(it); } }; for (size_t i = seq.size(); i != 0; --i) { const StmtEntry& s = seq[i - 1]; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && acc.type == kWrite) { fupdate(i, acc); write_seq.push_back(acc); } } for (const AccessEntry& acc : s.access) { if (acc.threads.size() != 0 && acc.type == kRead) { read_set[acc.buffer.get()].push_back(acc); } } } // loop carry if (loop != nullptr) { for (const AccessEntry& acc : write_seq) { fupdate(0, acc); } } for (const auto& kv : read_set) { write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end()); } return write_seq; } Stmt MakeBarrier(const std::string& func, const std::vector& wvec) { // insert write point Array wset; for (const AccessEntry& acc : wvec) { ICHECK(acc.dtype == wvec[0].dtype); wset.push_back(acc.touched); } Range none; Range r = arith::Union(wset).CoverRange(none); ICHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; return Evaluate(Call(DataType::Int(32), Op::Get(func), {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent})); } // Write barrier name bool read_barrier_{false}; std::string read_barrier_name_; std::string write_barrier_name_; const std::unordered_set& touched_; }; class CoProcInstDepDetector : public StmtVisitor { public: explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push"); sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop"); } void Plan(const Stmt& stmt) { this->VisitStmt(stmt); if (last_state_.node != nullptr) { MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); } } void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) { const IntImmNode* ctx_id = op->value.as(); ICHECK(ctx_id != nullptr); curr_state_.clear(); curr_state_.node = op->body.get(); curr_state_.enter_ctx.insert(ctx_id->value); curr_state_.exit_ctx.insert(ctx_id->value); UpdateState(); } else { StmtVisitor::VisitStmt_(op); } } void VisitStmt_(const ForNode* op) final { SyncState temp_first, temp_last; std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); this->VisitStmt(op->body); curr_state_.clear(); if (last_state_.node != nullptr) { curr_state_.node = op; ICHECK(first_state_.node != nullptr); // loop carry dependency InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop)); curr_state_.enter_ctx = first_state_.enter_ctx; curr_state_.exit_ctx = last_state_.exit_ctx; } std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); if (curr_state_.node != nullptr) { UpdateState(); } } void VisitStmt_(const IfThenElseNode* op) final { SyncState temp_first, temp_last, curr_state; std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); { // then stmt this->VisitStmt(op->then_case); if (last_state_.node != nullptr) { curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } first_state_.clear(); last_state_.clear(); } if (op->else_case.defined()) { this->VisitStmt(op->else_case); if (last_state_.node != nullptr) { curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } } // update in the trace. std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); std::swap(curr_state_, curr_state); if (curr_state_.node != nullptr) { UpdateState(); } } void VisitStmt_(const WhileNode* op) final { // TODO(masahi): Do we need a special handling for While nodes? LOG(FATAL) << "WhileNode not supported in CoProcSync."; } // insert before is stored in reverse order // the first element is closest to the node. std::unordered_map > insert_before_; std::unordered_map > insert_after_; private: // state in the sync entry struct SyncState { // The statement of the state. const Object* node{nullptr}; // Set of all possible contexts in the entering moment. std::unordered_set enter_ctx; // Set of all possible contexts in the exit moment. std::unordered_set exit_ctx; // existing pop performed at enter std::vector > enter_pop; // existing push peformed at exit std::vector > exit_push; // clear the state void clear() { node = nullptr; enter_ctx.clear(); exit_ctx.clear(); enter_pop.clear(); exit_push.clear(); } }; // inject proper sync into the pair // record the push/pop sequence that could be possibly un-matched. // return the push/pop message at enter/exit of the Block // after considering the existing unmatcheded events and added events void InjectSync(const SyncState& prev, const SyncState& next, std::vector >* prev_exit_push, std::vector >* next_enter_pop) { prev_exit_push->clear(); next_enter_pop->clear(); // quick path if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) { int from = *prev.exit_ctx.begin(); int to = *next.enter_ctx.begin(); if (from != to) { insert_after_[prev.node].emplace_back(MakePush(from, to)); insert_before_[next.node].emplace_back(MakePop(from, to)); prev_exit_push->emplace_back(std::make_pair(from, to)); next_enter_pop->emplace_back(std::make_pair(from, to)); } return; } // complicate path. std::vector > vpush = prev.exit_push; std::vector > vpop = next.enter_pop; std::vector > pending; for (int from : prev.exit_ctx) { for (int to : next.enter_ctx) { if (from != to) { pending.emplace_back(std::make_pair(from, to)); } } } // policy 1 std::vector prev_after, next_before; for (const std::pair& p : pending) { if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) { vpush.push_back(p); prev_after.emplace_back(MakePush(p.first, p.second)); } if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) { vpop.push_back(p); next_before.emplace_back(MakePop(p.first, p.second)); } } // fix pending for (const std::pair& p : vpush) { if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) { prev_after.emplace_back(MakePop(p.first, p.second)); } else { prev_exit_push->push_back(p); } } for (const std::pair& p : vpop) { if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) { next_before.emplace_back(MakePush(p.first, p.second)); } else { next_enter_pop->push_back(p); } } if (prev_after.size() != 0) { auto& v1 = insert_after_[prev.node]; v1.insert(v1.end(), prev_after.begin(), prev_after.end()); } if (next_before.size() != 0) { auto& v2 = insert_before_[next.node]; v2.insert(v2.end(), next_before.begin(), next_before.end()); } } void MatchFixEnterPop(const SyncState& state) { if (state.enter_pop.size() == 0) return; auto& vec = insert_before_[state.node]; for (const std::pair& p : state.enter_pop) { vec.push_back(MakePush(p.first, p.second)); } } void MatchFixExitPush(const SyncState& state) { if (state.exit_push.size() == 0) return; auto& vec = insert_after_[state.node]; for (const std::pair& p : state.exit_push) { vec.push_back(MakePop(p.first, p.second)); } } void UpdateState() { if (last_state_.node != nullptr) { std::vector > t1, t2; InjectSync(last_state_, curr_state_, &t1, &t2); std::swap(last_state_, curr_state_); } else { ICHECK(first_state_.node == nullptr); first_state_ = curr_state_; last_state_ = curr_state_; } } Stmt MakePush(int from, int to) { return Evaluate(Call(DataType::Int(32), sync_push_op_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)})); } Stmt MakePop(int from, int to) { return Evaluate(Call(DataType::Int(32), sync_pop_op_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)})); } // sync states. SyncState first_state_, last_state_, curr_state_; // Variables IterVar coproc_axis_; Op sync_push_op_, sync_pop_op_; }; class CoProcSyncInserter : public StmtMutator { public: Stmt Insert(Stmt stmt) { CoProcTouchedBuffer visitor; visitor(stmt); if (visitor.coproc_.size() == 0) return stmt; std::unordered_set touched; for (const auto& kv : visitor.touched_) { if (kv.second.normal && kv.second.coproc) { touched.insert(kv.first); } } ICHECK_EQ(visitor.coproc_.size(), 1U); std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint; // plan sync. CoProcSyncPlanner sync_planner(touched, coproc_name); sync_planner.Plan(stmt); for (const auto& kv : sync_planner.sync_) { auto& vec = insert_after_[kv.first]; vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } // Detect barrier CoProcBarrierDetector barrier_detector(touched, coproc_name); barrier_detector.PlanReadBarrier(stmt); barrier_detector.PlanWriteBarrier(stmt); for (const auto& kv : barrier_detector.barrier_before_) { auto& vec = insert_before_[kv.first]; vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } for (const auto& kv : barrier_detector.barrier_after_) { auto& vec = insert_after_[kv.first]; vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } // Detect barrier CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name); sync_detector.Plan(stmt); for (const auto& kv : sync_detector.insert_before_) { auto& vec = insert_before_[kv.first]; vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } for (const auto& kv : sync_detector.insert_after_) { auto& vec = insert_after_[kv.first]; vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } return operator()(std::move(stmt)); } Stmt VisitStmt(const Stmt& stmt) final { auto it_before = insert_before_.find(stmt.get()); auto it_after = insert_after_.find(stmt.get()); Stmt new_stmt = StmtMutator::VisitStmt(stmt); return SeqStmt::Flatten( it_before != insert_before_.end() ? it_before->second : std::vector(), new_stmt, it_after != insert_after_.end() ? it_after->second : std::vector()); } private: // insert before is stored in reverse order // the first element is closest to the node. std::unordered_map > insert_before_; std::unordered_map > insert_after_; }; Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } namespace transform { Pass CoProcSync() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = CoProcSyncInserter().Insert(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {}); } TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync); } // namespace transform } // namespace tir } // namespace tvm