/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file auto_scheduler/compute_dag.cc * \brief Compute declaration graph and its related analysis tools. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "../arith/pattern_match.h" #include "../relay/transforms/auto_scheduler_layout_rewrite.h" #include "search_policy/utils.h" #include "utils.h" namespace tvm { namespace auto_scheduler { using namespace tvm::tir; template using OperationMap = AccessAnalyzerNode::OperationMap; using OperationSet = std::unordered_set; TVM_REGISTER_NODE_TYPE(ComputeDAGNode); // Topo-sort ops from tensors according to their read-write relations. 
Array TopoSortOps(const Array& tensors) { std::unordered_map degree; std::unordered_map> edge_set; std::unordered_map priority; std::unordered_set visited; // traverse to build edge_set and count degree std::vector stack; stack.reserve(tensors.size()); for (const auto& x : tensors) { stack.push_back(x->op.operator->()); } int ct = 0; while (!stack.empty()) { const te::OperationNode* op = stack.back(); stack.pop_back(); if (visited.count(op)) { continue; } priority[op] = ct; ct++; visited.insert(op); if (op->IsInstance()) { degree[op] = 0; } else if (auto cop = GetRef(op).as()) { const Array& input_tensors = cop->InputTensors(); degree[op] = input_tensors.size(); for (const auto& ten : input_tensors) { edge_set[ten->op.operator->()].push_back(op); stack.push_back(ten->op.operator->()); } } else { LOG(FATAL) << "Unsupported op " << GetRef(op); } } // topo sort Array ops; using Item = std::pair; auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; }; std::priority_queue, decltype(cmp)> queue(cmp); for (const auto& iter : degree) { if (iter.second == 0) { queue.push(Item(iter.first, priority[iter.first])); } } ops.reserve(degree.size()); while (!queue.empty()) { Item item = queue.top(); queue.pop(); ops.push_back(GetRef(item.first)); for (const auto& dst : edge_set[item.first]) { degree[dst] -= 1; if (degree[dst] == 0) { queue.push(Item(dst, priority[dst])); } } } return ops; } // Extract all tensor accesses in an expr class ReadAccessExtractor : public StmtExprVisitor { public: void Extract(PrimExpr expr) { this->VisitExpr(expr); } void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::if_then_else())) { has_branch = true; } StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const ProducerLoadNode* op) final { read_access[Downcast(op->producer)->op].emplace_back(op->indices.begin(), op->indices.end()); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const IfThenElseNode* op) final { has_branch = true; 
StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const SelectNode* op) final { has_branch = true; StmtExprVisitor::VisitExpr_(op); } // All read accesses to all operations // The innermost vector stores multi-dimensional indices. // The middle vector stores possible multiple accesses OperationMap>> read_access; // Whether this expression has branch bool has_branch{false}; }; // Returns whether the expr equals to the var with an optional const shift bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { arith::PVar x; arith::PVar c; if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) && x.Eval().same_as(var)) { return true; } return false; } // Return whether the access to an operation is a simple access // (i.e. all index is just a variable with an optional constant shift) // For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not. bool IsSimpleAccess(const te::Operation& op, const std::vector& indices, bool* axis_missing, bool* axis_duplicated, bool* same_order) { auto cop = op.as(); if (cop == nullptr) { return false; } std::vector index_to_var_idx; std::vector var_idx_ct(cop->axis.size(), 0); for (const auto& expr : indices) { if (!is_const_int(expr)) { bool found = false; for (size_t i = 0; i < cop->axis.size(); ++i) { if (IsConstShiftEqual(cop->axis[i]->var, expr)) { index_to_var_idx.push_back(i); var_idx_ct[i]++; found = true; break; } } if (!found) { return false; } } } *axis_missing = false; // Some axes are missing *axis_duplicated = false; // Some axes appear more than once *same_order = true; // The axis order is the same as op->axis for (int ct : var_idx_ct) { if (ct == 0) { *axis_missing = true; } else if (ct > 1) { *axis_duplicated = true; } } for (size_t i = 1; i < index_to_var_idx.size(); ++i) { if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { *same_order = false; break; } } return true; } // Gather all VarNodes in an expr void GatherVars(const PrimExpr& expr, std::unordered_set* 
vars) { PostOrderVisit(expr, [&vars](const ObjectRef& node) { if (const VarNode* op = node.as()) { vars->insert(op); } }); } // Check whether an expr has expensive operations (e.g. exp) bool HasExpensiveOp(const PrimExpr& expr) { bool found = false; PostOrderVisit(expr, [&found](const ObjectRef& node) { if (const CallNode* op = node.as()) { if (op->op.as()->name == "tir.exp") { found = true; } } }); return found; } AccessAnalyzer::AccessAnalyzer(const Array& tensors) { auto node = make_object(); OperationMap has_branch; // Get all ops in topological order node->ops_topo_order = TopoSortOps(tensors); arith::Analyzer analyzer; // Build read & write access map for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { node->read_from[op] = OperationMap>>(); } else if (auto cop = op.as()) { ReadAccessExtractor extractor; for (const auto& exp : cop->body) { extractor.Extract(exp); } // read_by and read_from map for (const auto& iter : extractor.read_access) { std::vector>& accesses = node->read_by[iter.first][op]; accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); } node->read_from[op] = std::move(extractor.read_access); has_branch[op] = extractor.has_branch; // compute number of common outer iterators for (const auto& pair : node->read_from[op]) { const te::Operation& producer = pair.first; const std::vector>& access_list = pair.second; const Array& output_shape = op->output_shape(0); const Array& producer_shape = producer->output_shape(0); int n_common; for (n_common = 0; n_common < static_cast(std::min(output_shape.size(), producer_shape.size())); n_common++) { if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) { break; } bool injective = true; for (const auto& access : access_list) { if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) { injective = false; break; } } if (!injective) { break; } } node->num_common_outer_iterators[op][producer] = n_common; 
node->num_common_outer_iterators[producer][op] = n_common; } } else { LOG(FATAL) << "Invalid op: " << op; } } // Do some static analysis on ComputeOps for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { node->is_simple_access[op] = true; node->needs_multi_level_tiling[op] = false; node->is_strictly_inlineable[op] = false; node->is_output[op] = false; } else if (auto cop = op.as()) { // check whether this op is element-wise and strict-inlineable bool is_simple_access = true; bool is_strictly_inlineable = true; bool axis_missing, axis_duplicated, same_order; for (const auto& pair : node->read_from[op]) { const std::vector>& access_list = pair.second; for (const auto& access : access_list) { if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated, &same_order)) { is_simple_access = false; is_strictly_inlineable = false; break; } if (!same_order || axis_duplicated) { // do not strictly inline transpose is_strictly_inlineable = false; } } if (!is_simple_access) { break; } } // don't strictly inline expensive op (e.g. 
exp) bool has_expensive_op = false; for (const auto& expr : cop->body) { has_expensive_op |= HasExpensiveOp(expr); } if (has_expensive_op || has_branch[op]) { is_strictly_inlineable = false; } // constant tensor is strict-inlineable if (node->read_from[op].empty()) { is_strictly_inlineable = true; } node->is_simple_access[op] = is_simple_access; node->is_strictly_inlineable[op] = is_strictly_inlineable; // check whether the op needs multi-level tiling bool needs_multi_level_tiling = false; int n_missing = 0; for (const auto& pair : node->read_from[op]) { const std::vector>& access_list = pair.second; std::unordered_set vars; for (const std::vector& access : access_list) { for (const PrimExpr& expr : access) { GatherVars(expr, &vars); } } for (const auto& axis : cop->axis) { if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { n_missing++; break; } } if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) { needs_multi_level_tiling = true; break; } } // do not perform multi-level tiling on "fake reduction" with const tensors if (op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)) { needs_multi_level_tiling = false; } node->needs_multi_level_tiling[op] = needs_multi_level_tiling; // check whether the op is output node->is_output[op] = node->read_by[op].empty(); } else { LOG(FATAL) << "Invalid op" << op; } } data_ = std::move(node); } bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const { return operator->()->needs_multi_level_tiling.at(op); } bool AccessAnalyzer::IsOutput(const te::Operation& op) const { return operator->()->is_output.at(op); } bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const { return operator->()->is_simple_access.at(op); } bool AccessAnalyzer::IsStrictlyInlineable(const te::Operation& op) const { return operator->()->is_strictly_inlineable.at(op); } OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const { OperationSet 
inlined_ops; for (const auto& stage : state->stages) { if (stage->compute_at == ComputeAtKind::kInlined) { inlined_ops.insert(stage->op); } } OperationSet consumers; std::function collect; collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { for (const auto& iter : operator->()->read_by.at(op)) { if (inlined_ops.count(iter.first)) { collect(iter.first); } else { consumers.insert(iter.first); } } }; collect(op); return consumers; } OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const { OperationSet producers; for (const auto& iter : operator->()->read_from.at(op)) { producers.insert(iter.first); } return producers; } OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const { OperationSet inlined_ops; for (const auto& stage : state->stages) { if (stage->compute_at == ComputeAtKind::kInlined) { inlined_ops.insert(stage->op); } } OperationSet producers; std::function collect; collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) { for (const auto& iter : operator->()->read_from.at(op)) { if (inlined_ops.count(iter.first)) { collect(iter.first); } else { producers.insert(iter.first); } } }; collect(op); return producers; } int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, const te::Operation& target_op) const { int ret = INT32_MAX; bool meet = false; std::function traverse; traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) { if (cur_op == target_op) { ret = std::min(ret, cur_num); meet = true; return; } for (const auto& iter : operator->()->read_by.at(cur_op)) { traverse( iter.first, std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first))); } }; traverse(op, op->output_shape(0).size()); return meet ? 
ret : 0; } bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const { te::Operation cur_op = op; while (cur_op != target_op) { const AccessAnalyzerNode::OperationMap>>& map = operator->()->read_by.at(cur_op); if (map.size() != 1) { return false; } te::Operation next_op = map.begin()->first; // Check condition 1: They have the same output size auto p_cur = cur_op.as(); auto p_next = next_op.as(); if (p_cur == nullptr || p_next == nullptr) { return false; } Array output_shape = p_cur->output_shape(0); for (int i = 1; i < p_cur->num_outputs(); ++i) { if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { return false; } } for (int i = 0; i < p_next->num_outputs(); ++i) { if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { return false; } } // Check condition 2: The read is elementwise const std::vector> reads = map.begin()->second; bool is_simple_access, axis_missing, axis_duplicated, same_order; for (const auto& read : reads) { is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing, &axis_duplicated, &same_order); if (!is_simple_access || axis_missing || axis_duplicated || !same_order) { return false; } } cur_op = std::move(next_op); } return true; } // Estimate the number of float operations in an expression class FlopEstimator : public ExprFunctor { public: double EstimateFlop(const Array& ops) { double ret = 0; for (const auto& op : ops) { if (auto pop = op.as()) { if (pop->attrs.count("FLOP")) { // Use user-provided FLOP auto pint = pop->attrs["FLOP"].as(); ICHECK(pint != nullptr); ret += pint->value; } else { // Estimate by parsing the compute body double num_element = AxisLengthProd(pop->axis); if (num_element == -1) { fail_ = true; break; } cur_type_code_ = pop->output_dtype(0).code(); double op_per_element = 0; for (const auto& x : pop->body) { op_per_element += VisitExpr(x); } ret += num_element * op_per_element; } } else if (op->IsInstance()) { {} // do nothing } else { 
LOG(FATAL) << "Invalid op type " << op; } } return fail_ ? -1 : ret; } double VisitExpr_(const ReduceNode* op) final { uint64_t num_iter = 1; for (const auto& x : op->axis) { if (auto imm = x->dom->extent.as()) { num_iter *= imm->value; } else { fail_ = true; num_iter = -1; } } double body_flop = 0; for (size_t i = 0; i < op->combiner->result.size(); ++i) { body_flop += VisitExpr(op->combiner->result[i]); body_flop += VisitExpr(op->source[i]); } return num_iter * body_flop; } double VisitExpr_(const FloatImmNode* op) final { return 0.0; } double VisitExpr_(const IntImmNode* op) final { return 0.0; } double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; } double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } double VisitExpr_(const VarNode* op) final { return 0.0; } double VisitExpr_(const SelectNode* op) final { return VisitExpr(op->condition) + std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); } // Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS. #define VisitBinary(Node) \ double VisitExpr_(const Node* op) final { \ double base = 1.0; \ if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \ base = 0.0; \ } \ return base + VisitExpr(op->a) + VisitExpr(op->b); \ } #define VisitUnary(Node) \ double VisitExpr_(const Node* op) final { \ double base = op->dtype.code() == cur_type_code_ ? 
1.0 : 0.0; \ return base + VisitExpr(op->a); \ } VisitBinary(AddNode); VisitBinary(SubNode); VisitBinary(MulNode); VisitBinary(DivNode); VisitBinary(ModNode); VisitBinary(FloorDivNode); VisitBinary(FloorModNode); VisitBinary(MaxNode); VisitBinary(MinNode); VisitBinary(EQNode); VisitBinary(NENode); VisitBinary(LTNode); VisitBinary(LENode); VisitBinary(GTNode); VisitBinary(GENode); VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); double VisitExpr_(const CallNode* op) final { double ret = 0.0; for (const auto& x : op->args) { ret += VisitExpr(x); } return ret; } double VisitExprDefault_(const Object* op) final { fail_ = true; return -1.0; } private: bool fail_{false}; int cur_type_code_; }; void CheckComputeValidity(const te::Schedule& sch) { // Check the validity of a compute definition: // The name of each iterator should be unique. for (auto stage : sch->stages) { if (stage->op->IsInstance()) { std::unordered_set names; for (const auto& x : stage->leaf_iter_vars) { ICHECK(!names.count(x->var->name_hint)) << "Find duplicated iterator names in the compute definition: " << x->var->name_hint << ". 
Please use different names for different iterators."; names.insert(x->var->name_hint); } } } } ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); node->tensors = std::move(tensors); node->access_analyzer = AccessAnalyzer(node->tensors); Array out_ops; for (const auto& op : node->access_analyzer->ops_topo_order) { if (node->access_analyzer.IsOutput(op)) { out_ops.push_back(op); } } te::Schedule sch = te::create_schedule(out_ops); for (auto stage : sch->stages) { node->ops.push_back(stage->op); } // Make sure it is a valid compute definition CheckComputeValidity(sch); node->flop_ct = FlopEstimator().EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); } ComputeDAG::ComputeDAG(const te::Schedule& sch) { auto node = make_object(); // Make sure it is a valid compute definition CheckComputeValidity(sch); // Initialize ops. Here we enforce the order of ops and stages are consistent for (auto stage : sch->stages) { node->ops.push_back(stage->op); } // Collect input and output tensors Array tensors; for (auto stage : sch->stages) { if (stage->op->IsInstance() || stage->is_output) { for (auto i = 0; i < stage->op->num_outputs(); ++i) { tensors.push_back(stage->op.output(i)); } } } node->tensors = std::move(tensors); node->access_analyzer = AccessAnalyzer(node->tensors); node->flop_ct = FlopEstimator().EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); } class IndexRewriter : public StmtExprMutator { public: IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout) : placeholder_op_(placeholder_op) { ParseKernelLayout(new_layout, &new_shape_, &new_names_); } PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); } PrimExpr VisitExpr_(const ProducerLoadNode* op) final { te::Tensor t = Downcast(op->producer); if (t->op == placeholder_op_) { std::unordered_map name_to_arg; for (const auto& arg : op->indices) { std::string axis_name; if (const auto* int_imm = 
arg.as()) { ICHECK_EQ(int_imm->value, 0); axis_name = "IntImm"; } else { axis_name = AxisBaseName(CleanName(Downcast(arg)->name_hint)); ICHECK_EQ(name_to_arg.count(axis_name), 0); name_to_arg[axis_name] = arg; } } std::unordered_map div_factors; std::vector r_new_args; for (int i = new_names_.size() - 1; i >= 0; --i) { auto ori_iter_name = new_names_[i]; auto name_it = name_to_arg.find(ori_iter_name); ICHECK(name_it != name_to_arg.end()); PrimExpr ori_arg = name_it->second; PrimExpr mod_factor = new_shape_[i]; PrimExpr div_factor = 1; if (div_factors.count(ori_iter_name)) { div_factor = div_factors[ori_iter_name]; } div_factors[ori_iter_name] = div_factor * new_shape_[i]; PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); r_new_args.push_back(new_arg); } Array new_args(std::make_move_iterator(r_new_args.rbegin()), std::make_move_iterator(r_new_args.rend())); return ProducerLoad(op->producer, new_args); } return GetRef(op); } private: const te::Operation& placeholder_op_; Array new_shape_; std::vector new_names_; }; std::string GetOrigLayout(std::set* placeholder_axis_names, const te::Operation& op, const te::Tensor& placeholder) { ReadAccessExtractor extractor; for (const auto& exp : op.as()->body) { extractor.Extract(exp); } std::ostringstream os; uint32_t i = 0; const auto& placeholder_op = placeholder->op; ICHECK_GT(extractor.read_access.count(placeholder_op), 0); for (const auto& ev : extractor.read_access[placeholder_op]) { for (const auto& e : ev) { std::string axis_name; if (const auto* int_imm = e.as()) { ICHECK_EQ(int_imm->value, 0); axis_name = "IntImm"; } else { axis_name = AxisBaseName(CleanName(Downcast(e)->name_hint)); } placeholder_axis_names->insert(axis_name); os << placeholder->shape[i++] << axis_name; } } ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size()); std::string orig_layout = os.str(); os.str(""); ::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout); return 
orig_layout; } std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage, const te::Operation& op, const te::Tensor& placeholder, const std::set& placeholder_axis_names) { std::ostringstream os; Array stage_iters; auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); int attach_pos = -1; size_t iters_before_attach = 0; if (attach_it != state->attach_map->stage_to_attach_iter.end()) { auto attach = attach_it->second; const auto& attach_stage = state->stages[attach.first]; attach_pos = attach.second; stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(), attach_stage->iters.begin() + attach_pos + 1); } stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); std::vector iters; for (size_t i = 0; i < stage_iters.size(); ++i) { const auto& iter = stage_iters[i]; if (iter->orig_iters.empty()) { iters.push_back(iter); } else { for (const Iterator& ori_iter : iter->orig_iters) { iters.push_back(ori_iter); } } if (static_cast(i) == attach_pos) { iters_before_attach = iters.size(); } } std::vector new_names; std::vector new_axis_names; for (const Iterator& iter : iters) { std::set ori_iter_names; ExtractOriginalIterators(iter->name, &ori_iter_names); // fused iters have been replaced with iter->orig_iters. // So there should be only one ori iter name extracted from iter->name. ICHECK_EQ(ori_iter_names.size(), 1); auto ori_iter_name = AxisBaseName(*ori_iter_names.begin()); new_axis_names.push_back(ori_iter_name); } for (size_t i = 0; i < new_axis_names.size(); ++i) { auto iter = iters[i]; std::string ori_iter_name; if (i < iters_before_attach) { ori_iter_name = new_axis_names[i + iters_before_attach]; } else { ori_iter_name = new_axis_names[i]; } if (placeholder_axis_names.count(ori_iter_name)) { PrimExpr extent; if (iter->range.defined()) { extent = iter->range->extent; } else { // This iter is simplified by InferBound, so it must have a length of one. 
extent = 1; } os << extent << ori_iter_name; new_names.push_back(ori_iter_name); } } std::string new_layout = os.str(); os.str(""); ::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout); return new_layout; } ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, LayoutRewriteOption layout_rewrite) const { CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite) << "Call ComputeDAG::RewriteLayout with NoRewrite."; ComputeDAG new_dag = *this; ComputeDAGNode* p_dag = new_dag.CopyOnWrite(); auto node = make_object(); node->transform_steps = *transform_steps; node->concrete = true; const State& state = InferBound(State(node)); OperationSet handled_ops; for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) { const auto& stage = state->stages[stage_id]; const te::Operation& op = stage->op; if (!op->IsInstance()) { continue; } const Map& attrs = op->attrs; if (attrs.count(layout_free_placeholders_key) == 0) { continue; } const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; for (const auto& placeholder : Downcast>(attr_value)) { const auto& placeholder_op = placeholder->op; // Check whether this placeholder has already been handled if (handled_ops.count(placeholder_op)) { continue; } // Skip the op that is not direct consumer of this placeholder. // This is usually caused by cache read/write. 
bool direct_consumer = false; for (auto& t : op->InputTensors()) { if (t->op == placeholder_op) { direct_consumer = true; break; } } if (!direct_consumer) { continue; } handled_ops.insert(placeholder_op); // Process original layout std::set placeholder_axis_names; std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder); Array origin_shape; std::vector origin_axes; ParseKernelLayout(origin_layout, &origin_shape, &origin_axes); // Process new layout std::string new_layout = GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names); Array new_shape; std::vector new_axes; ParseKernelLayout(new_layout, &new_shape, &new_axes); // Process op updates te::Operation new_op_to_update; if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) { // Create new placeholder new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, placeholder_op.as()->dtype); } else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) { // Process index strides std::unordered_map axes_stride; for (const auto& i : origin_axes) { axes_stride[i] = Integer(1); } Array new_stride(new_shape.size(), PrimExpr()); PrimExpr temp = Integer(1); for (int i = new_shape.size() - 1; i >= 0; i--) { new_stride.Set(i, axes_stride[new_axes[i]]); axes_stride[new_axes[i]] *= new_shape[i]; } // Add an extra layout transform stage const auto& layout_transform_tensor = te::compute( new_shape, [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, &new_axes](const tvm::runtime::Array& indices) -> tvm::PrimExpr { Array access_indices; for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { PrimExpr temp = Integer(0); for (size_t i = 0; i < new_shape.size(); i++) { if (origin_axes[indice_index].compare(new_axes[i]) == 0) { temp += indices[i] * new_stride[i]; } } access_indices.push_back(temp); } return placeholder_op.output(0)(access_indices); }, "auto_scheduler_layout_transform"); new_op_to_update 
= layout_transform_tensor->op; // Update the transform steps for (size_t i = 0; i < transform_steps->size(); i++) { Step step = (*transform_steps)[i]; if (step->stage_id >= static_cast(stage_id)) { step.CopyOnWrite()->stage_id++; } if (step->IsInstance()) { auto compute_at_step = tvm::Downcast(step); if (compute_at_step->target_stage_id >= static_cast(stage_id)) { dynamic_cast(compute_at_step.CopyOnWrite())->target_stage_id++; } transform_steps->Set(i, std::move(compute_at_step)); } else { transform_steps->Set(i, std::move(step)); } } // Add schedule for the new added transform stage Array to_fuse; if (new_shape.size() >= 5) { to_fuse.push_back(0); to_fuse.push_back(1); to_fuse.push_back(2); transform_steps->push_back(FuseStep(stage_id, to_fuse)); } else if (new_shape.size() >= 3) { to_fuse.push_back(0); to_fuse.push_back(1); transform_steps->push_back(FuseStep(stage_id, to_fuse)); } transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel)); } te::Operation new_compute_op, original_compute_op; Array new_body; IndexRewriter index_rewriter(placeholder_op, new_layout); for (const auto& op : p_dag->ops) { if (auto* pop = op.as()) { bool need_update = false; for (auto& t : op->InputTensors()) { if (t->op == placeholder_op) { need_update = true; break; } } if (need_update) { for (const auto& body : pop->body) { new_body.push_back(index_rewriter.Rewrite(body)); } original_compute_op = op; CHECK(!new_compute_op.defined()); auto new_attrs = pop->attrs; new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout)); new_attrs.Set("new_placeholder_layout", tvm::String(new_layout)); new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body); } } } // construct the map from original_op to new_op std::unordered_map updated_ops; Array original_ops = p_dag->ops; p_dag->ops.clear(); for (size_t i = 0; i < original_ops.size(); ++i) { const auto& original_op = original_ops[i]; if (original_op == placeholder_op) { if 
(layout_rewrite == LayoutRewriteOption::InsertTransformStage) { p_dag->ops.push_back(placeholder_op); } p_dag->ops.push_back(new_op_to_update); updated_ops[placeholder_op] = new_op_to_update; } else if (original_op == original_compute_op) { p_dag->ops.push_back(new_compute_op); updated_ops[original_compute_op] = new_compute_op; } else { p_dag->ops.push_back(original_op); } } ArrayNode* pops = p_dag->ops.CopyOnWrite(); // Because ops is sorted in topo-order, only do one pass linear scan here. for (size_t i = 0; i < pops->size(); ++i) { const auto& original_op = Downcast(pops->at(i)); if (auto* pop = original_op.as()) { if (original_op == new_op_to_update) { continue; } auto inputs = pop->InputTensors(); std::unordered_map rmap; for (auto input : inputs) { auto it = updated_ops.find(input->op); te::Operation new_op; while (it != updated_ops.end()) { new_op = it->second; it = updated_ops.find(new_op); } if (new_op.defined()) { int index = input->value_index; rmap[input] = new_op.output(index); } } if (!rmap.empty()) { te::Operation new_op = pop->ReplaceInputs(original_op, rmap); updated_ops[original_op] = new_op; pops->SetItem(i, new_op); } } } Array old_tensors = p_dag->tensors; ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite(); for (size_t i = 0; i < old_tensors.size(); ++i) { const auto& old_tensor = old_tensors[i]; if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed && old_tensor->op->IsInstance()) { continue; } auto it = updated_ops.find(old_tensor->op); te::Operation new_op; while (it != updated_ops.end()) { new_op = it->second; it = updated_ops.find(new_op); } if (new_op.defined()) { auto index = old_tensor->value_index; p_tensors->SetItem(i, new_op.output(index)); } } } // end for placeholder } // end for stage p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors); Array out_ops; for (const auto& op : p_dag->access_analyzer->ops_topo_order) { if (p_dag->access_analyzer.IsOutput(op)) { out_ops.push_back(op); } } p_dag->ops.clear(); 
te::Schedule sch = te::create_schedule(out_ops); for (auto stage : sch->stages) { p_dag->ops.push_back(stage->op); } p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); p_dag->init_state = State(p_dag->ops); return new_dag; }

// Return whether a DAG has placeholders that are marked as "layout free".
// Scans dag->ops, skips every op that fails the IsInstance check, and reports
// true as soon as one remaining op carries the
// ComputeDAG::layout_free_placeholders_key entry in its attrs.
bool HasLayoutFreeTensors(const ComputeDAG& dag) {
  for (const auto& op : dag->ops) {
    if (!op->IsInstance()) {
      continue;
    }
    if (op->attrs.count(ComputeDAG::layout_free_placeholders_key)) {
      return true;
    }
  }

  return false;
}

// Replay `transform_steps` on a freshly created te::Schedule.
//
// \param transform_steps The steps to apply, in order.
// \param stages Optional output: receives one te::Stage per op of this DAG.
//        A local temporary is used when nullptr is passed.
// \param stage_to_axes Optional output: updated through UpdateStageToAxesMap
//        for every stage. A local temporary is used when nullptr is passed.
// \param layout_rewrite Layout rewrite option. When it is not NoRewrite, the
//        DAG has layout free tensors and `transform_steps` is non-empty, the
//        steps are copied, handed to RewriteLayout, and then applied on the
//        DAG that RewriteLayout returns.
// \return The pair (schedule, tensors of the DAG the steps were applied to).
std::pair> ComputeDAG::ApplySteps(
    const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes,
    LayoutRewriteOption layout_rewrite) const {
  if (layout_rewrite != LayoutRewriteOption::NoRewrite && HasLayoutFreeTensors(*this) &&
      !transform_steps.empty()) {
    // Work on a copy because RewriteLayout receives a mutable pointer to the steps.
    Array steps = transform_steps;
    const auto& dag = RewriteLayout(&steps, layout_rewrite);
    // NOTE(review): `stages` and `stage_to_axes` are not forwarded to this call, so
    // the caller's output containers are left untouched on this path -- confirm
    // that this is intended.
    return dag.ApplySteps(steps);
  }

  // Temporal object to be used if the input pointer is nullptr
  Array temp_stages;
  StageToAxesMap temp_stage_to_axes;
  if (stages == nullptr) {
    stages = &temp_stages;
  }
  if (stage_to_axes == nullptr) {
    stage_to_axes = &temp_stage_to_axes;
  }
  // Collect the ops that the access analyzer reports as outputs.
  Array out_ops;
  for (const auto& op : operator->()->ops) {
    if (operator->()->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }

  // Create the initial schedule
  te::Schedule schedule = te::create_schedule(out_ops);

  // init axes
  for (const auto& x : operator->()->ops) {
    const te::Stage& stage = schedule[x];
    stages->push_back(stage);
    UpdateStageToAxesMap(stage, stage_to_axes);
  }

  // Apply the history steps to TVM schedule
  // Call each step's ApplyToSchedule method
  for (const auto& step : transform_steps) {
    StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
  }

  return std::make_pair(schedule, operator->()->tensors);
}

// Build a string by concatenating what StepPrintAsPythonAPI returns for every
// step in `transform_steps`, preceded by one axis-unpacking line per matching stage.
//
// \param transform_steps The steps to print, in order.
// \return The accumulated text.
String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const {
  Array stages;
  StageToAxesMap stage_to_axes;
  // Collect the ops that the access analyzer reports as outputs.
  Array out_ops;
  for (const auto& op : operator->()->ops) {
    if (operator->()->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }
  // Create the initial schedule
  te::Schedule schedule = te::create_schedule(out_ops);

  // init axes
  for (const auto& x : operator->()->ops) {
    const te::Stage& stage = schedule[x];
    stages.push_back(stage);
    UpdateStageToAxesMap(stage, &stage_to_axes);
  }

  std::stringstream ss;
  // For every stage whose op passes the IsInstance check, emit
  //   "<iter0>, <iter1>, ... = tuple(<op>.op.axis) + tuple(<op>.op.reduce_axis)"
  // where the left-hand names come from the stage's leaf_iter_vars.
  for (const auto& stage : stages) {
    if (stage->op->IsInstance()) {
      auto op_name = CleanName(stage->op->name);

      for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
        ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name);
        // Separate names with ", " but do not emit a trailing separator.
        if (i != stage->leaf_iter_vars.size() - 1) {
          ss << ", ";
        }
      }
      ss << " = "
         << "tuple(" << op_name << ".op.axis)"
         << " + "
         << "tuple(" << op_name << ".op.reduce_axis)\n";
    }
  }
  // Call each step's PrintAsPythonAPI method
  for (const auto& step : transform_steps) {
    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
  }

  return ss.str();
}

// Render the ops of this DAG as text, one or more lines per op.
//
// \param simple_mode When true, the shape is omitted on "PLACEHOLDER" lines and,
//        for a non-reduction body that casts to a call, only the call's `op` is
//        printed instead of the full body expression.
// \return The rendered text. An op that matches neither of the two handled
//         kinds triggers LOG(FATAL) with "Invalid op".
String ComputeDAG::PrintDAG(bool simple_mode) const {
  std::stringstream ss;

  for (const auto& op : operator->()->ops) {
    if (op->IsInstance()) {
      ss << op->name << " = PLACEHOLDER ";
      if (!simple_mode) {
        ss << op.output(0)->shape;
      }
      ss << "\n";
    } else if (auto pop = op.as()) {
      // One output line per body expression of the op.
      for (size_t k = 0; k < pop->body.size(); ++k) {
        // Left-hand side: "<name>(<axis0>, <axis1>, ...)".
        ss << op->name << "(";
        for (size_t i = 0; i < pop->axis.size(); i++) {
          ss << pop->axis[i]->var->name_hint;
          if (i != pop->axis.size() - 1) {
            ss << ", ";
          }
        }
        ss << ")";
        // Only disambiguate with a ".v<k>" suffix when there are several bodies.
        if (pop->body.size() > 1) {
          ss << ".v" << k;
        }
        if (auto preduce = pop->body[k].as()) {
          ICHECK_LT(k, preduce->combiner->result.size());
          // The k-th combiner result decides which of the forms below is printed:
          // " += ", " max= ", " min= ", " select(...)= " or the "reduce" fallback.
          // NOTE(review): source[0] is printed regardless of k -- confirm this is
          // intended when the op has more than one body.
          PrimExpr combiner = preduce->combiner->result[k];
          if (combiner->IsInstance()) {
            ss << " += " << preduce->source[0] << "\n";
          } else if (combiner->IsInstance()) {
            ss << " max= " << preduce->source[0] << "\n";
          } else if (combiner->IsInstance()) {
            ss << " min= " << preduce->source[0] << "\n";
          } else if (combiner->IsInstance()) {
            // NOTE(review): this branch reads source[1] without a visible size
            // check -- presumably such a combiner always has two sources; verify.
            const auto& select = combiner.as();
            ss << " select(" << select->condition << ", " << select->true_value << ", "
               << select->false_value << ")= " << '(' << preduce->source[0] << ','
               << preduce->source[1] << ")\n";
          } else {
            ss << "reduce" << combiner << "\n";
          }
        } else {
          auto call = pop->body[k].as();
          if (simple_mode && call) {
            ss << " = " << call->op << "\n";
          } else {
            ss << " = " << pop->body[k] << "\n";
          }
        }
      }
    } else {
      LOG(FATAL) << "Invalid op";
    }
  }
  return String(ss.str());
}

// Fill in the bound information of every non-inlined stage of a state.
//
// \param state The input state; `state->concrete` must hold (enforced by ICHECK).
// \return A state whose stages carry iterators rebuilt with the ranges that
//         te::InferBound reports for the replayed schedule. LOG(FATAL) is raised
//         with "Infer bound fails" when an axis has no entry in the bound map.
State ComputeDAG::InferBound(const State& state) const {
  ICHECK(state->concrete) << "Only concrete state can be processed to get bound info.";

  State ret_state;
  StateNode* pstate;

  if (state->stages.empty()) {
    // If the input state is incomplete with empty operation stage
    // create a new state from init_state and update it first
    ret_state = operator->()->init_state;
    pstate = ret_state.CopyOnWrite();
    pstate->transform_steps = state->transform_steps;
    for (const auto& step : pstate->transform_steps) {
      StepApplyToState(step, &ret_state, *this);
    }
  } else {
    ret_state = state;
    pstate = ret_state.CopyOnWrite();
  }

  Array stages;
  StageToAxesMap stage_to_axes;
  te::Schedule sch;
  Array tensors;
  // Replay steps to tvm::Schedule
  std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
  sch = sch.normalize_for_feature_extraction();
  // Get bound information from TVM schedule
  Map bounds = te::InferBound(sch);

  // Update the state bound information
  for (size_t i = 0; i < pstate->stages.size(); ++i) {
    const Stage& stage = pstate->stages[i];

    // Inlined stages are left as they are.
    if (stage->compute_at == ComputeAtKind::kInlined) {
      continue;
    }

    Array new_iters;
    new_iters.reserve(stage->iters.size());
    // Get bound information from schedule
    // the StageToAxesMap is used to find the corresponding IterVar in TVM schedule result
    for (size_t j = 0; j < stage->iters.size(); ++j) {
      const Iterator& iter = stage->iters[j];
      // The i-th state stage and its j-th iterator are looked up at the same
      // positions in the replayed schedule's stages / axes.
      const IterVar& axis = stage_to_axes.at(stages[i])[j];

      auto find_res = bounds.find(axis);
      if (find_res != bounds.end()) {
        // Same iterator as before, except that its range comes from the bound map.
        new_iters.push_back(Iterator(iter->name, (*find_res).second, iter->iter_kind,
                                     iter->annotation, &iter->orig_iters));
      } else {
        LOG(FATAL) << "Infer bound fails";
      }
    }

    pstate->stages.Set(
        i, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
  }

  return ret_state;
}

// Batch version of InferBound, run through support::parallel_for.
//
// \param states The input states; undefined entries are copied through unchanged.
// \return An array of the same size. When InferBound throws an Error for an
//         entry, a warning is logged and that output slot keeps the State() it
//         was initialized with.
Array ComputeDAG::InferBound(const Array& states) const {
  Array out_states(states.size(), State());

  support::parallel_for(0, states.size(), [this, &states, &out_states](int i) {
    try {
      out_states.Set(i, (states[i].defined()) ? this->InferBound(states[i]) : states[i]);
    } catch (Error& e) {
      LOG(WARNING) << "InferBound fails on the state:\n"
                   << states[i] << "\n"
                   << "with: " << e.what() << std::endl;
    }
  });

  return out_states;
}

// Apply `transform_steps` with ApplySteps and construct a new ComputeDAG from
// the resulting schedule. The tensors returned by ApplySteps are not used.
ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array& transform_steps) const {
  te::Schedule sch;
  Array old_tensors;
  std::tie(sch, old_tensors) = ApplySteps(transform_steps);
  return ComputeDAG(sch);
}

// ReprPrinter dispatch that dumps the per-op analysis tables of a node:
// the flags and read relations of each op in ops_topo_order, then the
// ElementWiseMatch pairs, then the num_common_outer_iterators entries.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast(ref.get());
      for (const auto& op : node->ops_topo_order) {
        p->stream << op << std::endl;
        p->stream << "is_simple_access:\t" << node->is_simple_access.at(op) << "\t\t";
        p->stream << "needs_multi_level_tiling:\t" << node->needs_multi_level_tiling.at(op)
                  << std::endl;
        p->stream << "is_strictly_inlinable:\t" << node->is_strictly_inlineable.at(op) << "\t";
        p->stream << "is_output:\t" << node->is_output.at(op) << std::endl;
        // One "<name><indices>, " item per recorded access in read_from[op].
        p->stream << "Read from:\t";
        for (const auto& pair : node->read_from.at(op)) {
          for (const auto& index : pair.second) {
            p->stream << pair.first->name << Array(index) << ", ";
          }
        }
        p->stream << std::endl;
        // Same format for the accesses recorded in read_by[op].
        p->stream << "Read by:\t";
        for (const auto& pair : node->read_by.at(op)) {
          for (const auto& index : pair.second) {
            p->stream << pair.first->name << Array(index) << ", ";
          }
        }
        p->stream << std::endl;
        p->stream << Chars('=', 50) << std::endl;
      }

      AccessAnalyzer ana = GetRef(node);
      // Test every ordered pair of distinct ops and print the pairs that match.
      p->stream << "ElementwiseMatch: \n";
      for (size_t i = 0; i < node->ops_topo_order.size(); ++i) {
        for (size_t j = 0; j < node->ops_topo_order.size(); ++j) {
          if (i == j) {
            continue;
          }
          if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) {
            p->stream << node->ops_topo_order[i]->name << " -> " << node->ops_topo_order[j]->name
                      << std::endl;
          }
        }
      }
      p->stream << Chars('=', 50) << std::endl;

      // One "<src name> <dst name> <count>" line per stored entry.
      p->stream << "NumCommonOuterIterators: \n";
      for (const auto& src_pair : node->num_common_outer_iterators) {
        for (const auto& dst_pair : src_pair.second) {
          p->stream << src_pair.first->name << " " << dst_pair.first->name << " "
                    << dst_pair.second << std::endl;
        }
      }
      p->stream << Chars('=', 50) << std::endl;
    });

// ReprPrinter dispatch that prints a node as the text returned by PrintDAG()
// called without arguments.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast(ref.get());
      auto dag = GetRef(node);
      auto dag_str = dag.PrintDAG();
      p->stream << dag_str;
    });

// Compute one extent per entry of `axis_names` from a rewritten layout string.
//
// \param rewritten_layout The layout string, parsed by
//        topi::parse_auto_scheduler_layout into shape entries and axis names.
// \param axis_names The axis names to compute extents for.
// \return An array with one entry per axis name: the product of all parsed
//         shape entries whose parsed name equals that axis name (1 when there
//         is no such entry). CHECK_EQ fails when the number of matches differs
//         from the number of parsed names.
Array GetShapeFromRewrittenLayout(String rewritten_layout, Array axis_names) {
  Array shape;
  std::vector extracted_names;
  topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names);

  Array ret(axis_names.size(), 1);

  // Counts how many (axis name, parsed name) pairs matched.
  size_t ct = 0;
  for (size_t i = 0; i < axis_names.size(); ++i) {
    for (size_t j = 0; j < extracted_names.size(); ++j) {
      if (axis_names[i] == extracted_names[j]) {
        ret.Set(i, ret[i] * shape[j]);
        ct++;
      }
    }
  }

  CHECK_EQ(ct, extracted_names.size()) << "The number or names of axes do not match";

  return ret;
}

// Global "auto_scheduler.ComputeDAG": construct a ComputeDAG from the schedule
// when one is given, otherwise from the tensors. Passing neither fails the ICHECK.
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
    .set_body_typed([](Optional> tensors, Optional sch) {
      if (sch) {
        return ComputeDAG(sch.value());
      }
      ICHECK(tensors) << "Both tensors and schedule are null";
      return ComputeDAG(tensors.value());
    });

// Global "auto_scheduler.ComputeDAGApplyStepsFromState": apply the state's
// transform steps with the given layout rewrite option (passed as an int and
// cast) and return the schedule and the tensors packed into one Array.
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) {
      te::Schedule sch;
      Array return_tensors;
      std::tie(sch, return_tensors) =
          dag.ApplySteps(state->transform_steps, nullptr, nullptr,
                         static_cast(layout_rewrite));
      return Array{sch, return_tensors};
    });
// Global "auto_scheduler.ComputeDAGPrintPythonCodeFromState": return
// PrintStepsAsPython for the transform steps of `state`.
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintPythonCodeFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state) {
      return dag.PrintStepsAsPython(state->transform_steps);
    });

// Global "auto_scheduler.ComputeDAGPrintDAG": return PrintDAG(simple_mode).
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintDAG")
    .set_body_typed([](const ComputeDAG& dag, bool simple_mode) {
      return dag.PrintDAG(simple_mode);
    });

// Global "auto_scheduler.ComputeDAGInferBoundFromState": return InferBound(state).
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state) {
      return dag.InferBound(state);
    });

// Global "auto_scheduler.ComputeDAGRewriteLayoutFromState": call RewriteLayout
// with LayoutRewriteOption::RewriteForPreTransformed on the steps of `state`.
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state) {
      // NOTE(review): constness is cast away from state->transform_steps and the
      // pointer is handed to RewriteLayout, so the steps stored in the caller's
      // state may be modified in place -- confirm that this is intended.
      Array* transform_steps = const_cast*>(&state->transform_steps);
      return dag.RewriteLayout(transform_steps, LayoutRewriteOption::RewriteForPreTransformed);
    });

// Global "auto_scheduler.RewriteIndexForNewLayout": build an IndexRewriter from
// the placeholder op and the new layout string, and return its Rewrite of `body`.
TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
    .set_body_typed([](const te::Operation& placeholder_op, const std::string& new_layout,
                       const PrimExpr& body) {
      IndexRewriter index_rewriter(placeholder_op, new_layout);
      return index_rewriter.Rewrite(body);
    });

// Global "auto_scheduler.GetShapeFromRewrittenLayout": exposes the free function
// of the same name.
TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
    .set_body_typed(GetShapeFromRewrittenLayout);

}  // namespace auto_scheduler
}  // namespace tvm