/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Logic related to cross-thread reduction, used by ComputeOpNode.
 * \file cross_thread_reduction.cc
 */
#include <tvm/tir/builtin.h>

#include "compute_op.h"
#include "op_utils.h"

namespace tvm {
namespace te {
using namespace tir;

//
// Cross thread reduction transformation.
//
// The input loop nest in generic form (single reduction/thread case):
//
//   let m be the reduction extent
//   let N be the thread extent
//   let input_pred be the predicate on the reduction
//
//   B[..] = 0
//   for (tid, 0, N)
//     for (i, 0, floordiv(m+N-1, N))
//       if (i + tid * floordiv(m+N-1, N) < m)
//         if (input_pred)
//           B[..] = op(B[..], A[i + tid * floordiv(m+N-1,N)])
//
// The threaded reduction looks like:
//
// (1) normal reductions (leaves)
//
//   for (i, 0, floordiv(m+N-1, N))
//     if (i + tid * floordiv(m+N-1, N) < m)
//       if (input_pred)
//         B_temp[0] = op(B_temp[0], A[i + tid * floordiv(m+N-1,N)])
//
// (2) threaded reduction, which does not require predicates, as an
//     identity element is filled in when out of bounds
//
//   tvm_thread_allreduce(size, B_temp, (bool)1, tid)
//
// The last step is to write back the final reduction variable,
// which should be predicated by the existing input_pred, if any.
// The consequence is that input_pred should be independent of
// the reduction axis. Otherwise, we need to separate it into
// a dependent part and an independent one.
//
// (3) write back
//
//   if (input_pred)
//     B[..] = B_temp[0]
//
// In summary, we need two predicates:
//
// * the original input_pred from the reduction itself
//
// * the normal reduction axis predicate
//     normal_pred = (i + tid * floordiv(m+N-1,N)) < m
//   which depends on the normal reduction variable.
//
// input_pred is applied to both the normal reduction and
// the writeback step.
//
Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
                              const std::unordered_map<IterVar, Range>& dom_map,
                              bool debug_keep_trivial_loop) {
  Array<PrimExpr> args;
  for (IterVar iv : self->axis) {
    args.push_back(iv->var);
  }
  std::unordered_map<IterVar, PrimExpr> value_map;
  auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map,
                           debug_keep_trivial_loop);

  size_t size = self->body.size();
  ICHECK_GT(size, 0);
  std::vector<const ReduceNode*> reduces(size);
  for (size_t i = 0; i < size; ++i) {
    const ReduceNode* reduce = self->body[i].as<ReduceNode>();
    ICHECK(reduce);
    ICHECK(reduce->init.empty())
        << "Cannot perform cross_thread_reduction for reductions with init";
    reduces[i] = reduce;
  }

  // This computes the bound-checking predicates for the normal reduction.
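  // Illustrative example (hypothetical numbers, not part of the lowering):
  // for a reduction extent m = 100 split across N = 32 threads, each thread
  // serially reduces floordiv(100 + 31, 32) = 4 elements, so the bound check
  // below should yield a predicate of the form
  //   normal_pred = (i + tid * 4 < 100)
  // matching normal_pred in the header comment.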
  auto normal_preds =
      MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());

  // normal_pred = input_pred && normal_pred
  PrimExpr input_pred = reduces[0]->condition;
  normal_preds.push_back(input_pred);
  normal_preds.erase(std::remove_if(normal_preds.begin(), normal_preds.end(),
                                    [](const PrimExpr& e) { return !e.defined(); }),
                     normal_preds.end());

  // Partition the loop nest: reduction axes bound to a thread go into the
  // common (outer) nest; unbound reduction axes form the normal (serial)
  // reduction nest.
  std::vector<std::vector<Stmt>> common, normal_red;
  for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        common.emplace_back(nest[i + 1]);
      } else {
        normal_red.emplace_back(nest[i + 1]);
      }
    } else {
      common.emplace_back(nest[i + 1]);
    }
  }

  // If we load from and then store into the same res_handles in the
  // thread_allreduce intrinsic, the result is wrong, so we use an extra
  // variable here for the normal reduction.
  std::vector<Var> normal_res_handles;
  std::vector<Stmt> normal_init, normal_update;
  if (!normal_red.empty()) {
    normal_res_handles.reserve(size);
    normal_init.reserve(size);
    normal_update.reserve(size);
    const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
    ICHECK(combiner);
    Array<PrimExpr> lhs;
    for (size_t i = 0; i < size; ++i) {
      DataType t = reduces[i]->dtype;
      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
                                      PointerType(PrimType(t), "local"));
      lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
    }
    Array<PrimExpr> init_value = combiner->identity_element;
    Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
    for (size_t i = 0; i < size; ++i) {
      DataType t = reduces[i]->dtype;
      normal_init.emplace_back(
          Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
      normal_update.emplace_back(
          Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
    }
  }

  Array<PrimExpr> freduce_args;
  freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
  for (size_t i = 0; i < size; ++i) {
    if (!normal_red.empty()) {
      DataType t = reduces[i]->dtype;
      freduce_args.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
    } else {
      freduce_args.push_back(reduces[0]->source[i]);
    }
  }

  // No constraints on the thread reduction step. It may have redundant
  // computation for rare cases. TODO(tvm-team): revisit this.
  freduce_args.push_back(const_true(1));

  std::vector<Var> res_handles(size);
  for (size_t idx = 0; idx < size; ++idx) {
    DataType dtype = reduces[idx]->dtype;
    res_handles[idx] =
        Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local"));
    freduce_args.push_back(res_handles[idx]);
  }

  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
        IterVar tv = (*it).second->bind_thread;
        freduce_args.push_back(tv->var);
      }
    }
  }

  // Predicates guarding the final writeback to the output.
  std::vector<PrimExpr> output_preds;
  if (stage->store_predicate.defined()) {
    output_preds.emplace_back(stage->store_predicate);
  }

  // Apply the existing input predicate, if any.
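  // Note (per the header comment): this guard is only sound when input_pred
  // is independent of the normal reduction axis. A hypothetical predicate on
  // a data-parallel axis, e.g. (j < n), may guard the writeback; a predicate
  // mentioning the serial reduction variable i must stay in normal_preds.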
  output_preds.push_back(input_pred);

  Stmt reduce_body =
      Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args));
  reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
                         make_zero(DataType::Handle()), reduce_body);

  if (!normal_red.empty()) {
    Stmt init_body = SeqStmt::Flatten(normal_init);
    Stmt update_body = SeqStmt::Flatten(normal_update);
    update_body = MergeNest(MakeIfNest(normal_preds), update_body);
    update_body = MergeNest(normal_red, update_body);
    reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
  }

  std::vector<Stmt> assigns(size);
  for (size_t idx = 0; idx < size; ++idx) {
    DataType t = reduces[idx]->dtype;
    assigns[idx] = ProducerStore(stage->op.output(idx),
                                 Load(t, res_handles[idx], 0, const_true(t.lanes())), args);
  }
  Stmt assign_body = SeqStmt::Flatten(assigns);
  assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
  Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
  for (size_t idx = size; idx != 0; --idx) {
    body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
    if (!normal_red.empty()) {
      body =
          Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
    }
  }
  body = Substitute(body, value_map);
  return MergeNest(common, body);
}

}  // namespace te
}  // namespace tvm
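// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; hypothetical variable names, not compiled
// as part of this file): this lowering is reached once a schedule binds a
// reduction axis to a thread IterVar. Assuming the C++ TE API, a minimal
// schedule that triggers it could look like:
//
//   using namespace tvm;
//   using namespace tvm::te;
//
//   Var n("n");
//   Tensor A = placeholder({n}, DataType::Float(32), "A");
//   IterVar k = reduce_axis(Range(0, n), "k");
//   Tensor B = compute({1}, [&](Var i) { return sum(A(k), {k}); }, "B");
//
//   Schedule s = create_schedule({B->op});
//   IterVar ko, ki;
//   s[B].split(k, 32, &ko, &ki);
//   s[B].bind(ki, thread_axis(Range(), "threadIdx.x"));
//
// After the bind, ki is placed in `common` and becomes the thread argument
// of tvm_thread_allreduce, while ko stays in `normal_red` as the serial
// leaf reduction of step (1) in the header comment.
// ---------------------------------------------------------------------------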