/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Lower warp memory to use local memory * and shuffle intrinsics. * * \file lower_warp_memory.cc */ // Thanks to Andrew Adams and Vinod Grover for // explaining the concept of warp shuffle. #include #include #include #include #include #include #include #include #include #include #include #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" #include "update_pointer_storage_scope.h" namespace tvm { namespace tir { // Rewrite Rule // // There is no special warp memory in most GPUs. // Instead, we can stripe the data into threads // and store the data into local memory. // // This requires us to do the following rewriting: // - Rewrite allocation to use local memory. // - Rewrite store of warp memory to local store. // - Rewrite load of warp memory to local plus a shuffle. // // Define a generic shuffle intrinsic warp_shuffle(data, warp_index). // We can use the following rewriting rule // // Before rewrite, // // alloc warp warp_mem[n * width * m] // store warp_mem[m * warp_index + (width * m) * y + x] // load warp_mem[m * z + (width * m) * y + x] // subject to x \in [0, m), y \in [0, n) // // where width equals to the extent of threadIdx.x, which should // be no larger than the warp size // // After rewrite: // // alloc local local_mem[n * m] // store warp_mem[m * y + x] // warp_shuffle(load warp_mem[m * y + x], z) // subject to (m * y + x) is invariant to warp_index // // If width == warp size, we are shuffling on full warps. // Otherwise, we are virtually shuffling on sub-warps, // whose size equals to width. In this case, you can imagine // a warp only consists of `width` threads. Width is passed // as an argument to the shuffle primitive, and will be // lowered to the device code if the target supports. // // A limitation of this sub-warp approach is that users // cannot shuffle across the sub-warp boundary (i.e. shuffle // with threadIdx.y or threadIdx.z indices). It can be solved // via fusing threadIdx.x to the warp size, or improving the // analyzer to detect both 3 thread axes, which is left for // future improvements. // Algorithm // // To implement this rewrite rule, we can do the follow step: // For each warp memory alloc // - Use linear pattern detector on load index to find m // - Deduce n given width and alloc size // - Now that we have m, n, width, we can proceed with the rewrite // Visitor to find m in pattern // store warp_mem[m * warp_index + (width * m) * y + x] class WarpStoreCoeffFinder : private StmtVisitor { public: WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { this->VisitStmt(stmt); return warp_coeff_; } private: /// Visitor implementation void VisitStmt_(const StoreNode* op) final { if (op->buffer_var.get() == buffer_) { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); } else { arith::PVar base; ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index)) << "LowerWarpMemory failed due to store index=" << op->index << ", can only handle continuous store"; UpdatePattern(base.Eval()); } } else { StmtVisitor::VisitStmt_(op); } } void UpdatePattern(const PrimExpr& index) { Array m = arith::DetectLinearEquation(index, {warp_index_}); ICHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed. Could not simplify the store index `" << index << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " "thread local registers and shuffling values between these registers. Currently only " "linear equation indices are supported."; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); const auto* mcoeff_as_int = mcoeff.as(); ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index << ", require positive constant coefficient on warp index " << warp_index_ << " but get " << mcoeff; if (warp_coeff_ != 0) { ICHECK_EQ(warp_coeff_, mcoeff_as_int->value) << "LowerWarpMemory failed due to two different store coefficient to warp index"; } else { warp_coeff_ = mcoeff_as_int->value; } } // The buffer variable const VarNode* buffer_; // the warp index Var warp_index_; // the coefficient int64_t warp_coeff_{0}; // analyzer. arith::Analyzer* analyzer_; }; // Visitor to find the warp index class WarpIndexFinder : private StmtVisitor { public: explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {} // find the warp co-efficient and the shuffle width in the statement std::pair Find(const Stmt& stmt) { this->VisitStmt(stmt); ICHECK(warp_index_.defined()) << "Cannot find warp index(threadIdx.x) within the scope of warp memory"; return std::make_pair(warp_index_->var, width_); } private: /// Visitor implementation void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { auto* value_as_int = op->value.as(); ICHECK(value_as_int && value_as_int->value <= warp_size_ && warp_size_ % value_as_int->value == 0) << "Expect threadIdx.x 's size to be no larger than, and a factor of" << " warp size(" << warp_size_ << ")" << " to enable warp memory" << " but get " << op->value << " instead"; if (warp_index_.defined()) { ICHECK(warp_index_.same_as(iv)) << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. " << "Please create it using thread_axis once and reuse the axis " << "across multiple binds in the same kernel"; } else { width_ = value_as_int->value; warp_index_ = iv; } } } StmtVisitor::VisitStmt_(op); } // warp size int warp_size_{0}; // number of threads involved in one shuffle int width_{0}; // the warp index IterVar warp_index_{nullptr}; }; // Mutator to change the read pattern class WarpAccessRewriter : protected StmtExprMutator { public: explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) : warp_size_(warp_size), analyzer_(analyzer) {} // Rewrite the allocate statement which transforms // warp memory to local memory. Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); // Align the local memory size. The number of elements may not // be a multiple of width_ * warp_coeff_; round it up. int factor = width_ * warp_coeff_; ICHECK_NE(factor, 0) << "Divide by zero"; warp_group_ = (alloc_size + (factor - 1)) / factor; alloc_size = warp_group_ * factor; return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, op->condition, this->VisitStmt(op->body)); } protected: PrimExpr VisitExpr_(const VarNode* op) override { ICHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const StoreNode* op) override { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); PrimExpr new_value = VisitExpr(op->value); return Store(op->buffer_var, new_value, local_index, op->predicate); } else { return StmtExprMutator::VisitStmt_(op); } } PrimExpr VisitExpr_(const LoadNode* op) override { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); if (analyzer_->CanProveEqual(group, warp_index_)) { return load_value; } PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), {mask, load_value, group, width_, warp_size_}); } else { return StmtExprMutator::VisitExpr_(op); } } // Split the index to the two component // // local index is the index in the local // source index is the corresponding source index // in this access pattern. std::pair SplitIndexByGroup(const PrimExpr& index) { if (index.dtype().lanes() != 1) { PrimExpr local_index, group; arith::PVar base; ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); std::tie(local_index, group) = SplitIndexByGroup(base.Eval()); local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } PrimExpr m = make_const(index.dtype(), warp_coeff_); // simple case, warp index is on the highest. if (warp_group_ == 1) { PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); PrimExpr z = analyzer_->canonical_simplify(indexdiv(index, m)); return std::make_pair(x, z); } else { PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_); y = y * m + x; PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m); return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); } } private: // the warp size int warp_size_{0}; // The buffer variable const VarNode* buffer_; // number of threads involved in one shuffle int width_{0}; // Warp index Var warp_index_; // the coefficient m int warp_coeff_{0}; // the coefficient n int warp_group_{0}; // Internal analyzer arith::Analyzer* analyzer_; }; // Bind bound information of variables to make analyzer more effective // TODO(tqchen): consider a pass to inline the bound info into the expr // so analysis can be context independent. class BindVarBoundInfo : public StmtVisitor { public: explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} void VisitStmt_(const ForNode* op) final { const Var& loop_var = op->loop_var; analyzer_->Bind(loop_var, Range::FromMinExtent(op->min, op->extent)); StmtVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); if (!var_dom_.count(iv->var.get())) { Range dom = Range::FromMinExtent(0, op->value); var_dom_[iv->var.get()] = dom; analyzer_->Bind(iv->var, dom); } } StmtVisitor::VisitStmt_(op); } protected: // internal analyzer. arith::Analyzer* analyzer_; // variable domain std::unordered_map var_dom_; }; // Mutator to change the read pattern class WarpMemoryRewriter : private StmtMutator { public: explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {} Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; BindVarBoundInfo binder(&analyzer_); binder(stmt); stmt = operator()(std::move(stmt)); return stmt; } std::unordered_map new_storage_scopes_; private: Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); if (GetPtrStorageScope(op->buffer_var) == "warp") { new_storage_scopes_[op->buffer_var.get()] = "local"; WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; } int warp_size_{0}; arith::Analyzer analyzer_; // variable domain std::unordered_map var_dom_; }; namespace transform { Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); } // namespace transform } // namespace tir } // namespace tvm