/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file auto_scheduler/search_policy/utils.h * \brief Common utilities for search policies. */ #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_ #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_ #include <dmlc/common.h> #include <tvm/auto_scheduler/loop_state.h> #include <tvm/auto_scheduler/search_policy.h> #include <tvm/ir/expr.h> #include <tvm/te/operation.h> #include <algorithm> #include <condition_variable> #include <set> #include <string> #include <tuple> #include <unordered_map> #include <unordered_set> #include <utility> #include <vector> #include "../utils.h" namespace tvm { namespace auto_scheduler { /*! \brief Return whether the search task is targeting a CPU. */ inline bool IsCPUTask(const SearchTask& task) { return (task)->target->kind->device_type == kDLCPU; } /*! \brief Return whether the search task is targeting a GPU. */ inline bool IsGPUTask(const SearchTask& task) { return (task)->target->kind->device_type == kDLCUDA || (task)->target->kind->device_type == kDLOpenCL || (task)->target->kind->device_type == kDLVulkan || (task)->target->kind->device_type == kDLMetal || (task)->target->kind->device_type == kDLROCM || (task)->target->kind->device_type == kOpenGL; } /*! 
\brief Return whether the search task is targeting a CUDA GPU. */ inline bool IsCUDATask(const SearchTask& task) { return (task)->target->kind->device_type == kDLCUDA; } /*! \brief Return whether the search task is targeting a OpenCL GPU. */ inline bool IsOpenCLTask(const SearchTask& task) { return (task)->target->kind->device_type == kDLOpenCL; } /*! \brief Argsort. Order: largest to smallest */ template <typename T> inline std::vector<int> Argsort(const std::vector<T>& scores) { std::vector<int> index; index.reserve(scores.size()); for (size_t i = 0; i < scores.size(); ++i) { index.push_back(i); } auto cmp = [&scores](int l, int r) { return scores[l] > scores[r]; }; std::sort(index.begin(), index.end(), cmp); return index; } /*! \brief Convert operation to stage id. */ inline int OperationToStage(const te::Operation& op, const State& state) { for (size_t i = 0; i < state->stages.size(); ++i) { if (op == state->stages[i]->op) { return i; } } LOG(FATAL) << "Cannot find op: " << op; return -1; } /********** Get Parameters **********/ /*! \brief Get an integer from a tvm str Map. */ inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto pint = attr_dict[key].as<IntImmNode>(); ICHECK(pint != nullptr); return pint->value; } /*! \brief Get a double from a tvm str Map. */ inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto pdouble = attr_dict[key].as<FloatImmNode>(); ICHECK(pdouble != nullptr); return pdouble->value; } /*! \brief Get a string from a tvm str Map. 
*/ inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; const auto& target = attr_dict[key]; if (auto pstr = target.as<StringImmNode>()) { return pstr->value; } auto pstr = target.as<StringObj>(); ICHECK(pstr != nullptr); return pstr->data; } /*! \brief Get a iterator name set from a tvm str Map. */ inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) { std::set<std::string> ret; ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto names = attr_dict[key].as<ArrayNode>(); ICHECK(names != nullptr); for (const auto& name : *names) { ret.insert(name.as<StringObj>()->data); } return ret; } /********** Checks with ComputeDAG **********/ /*! \brief Return whether an op is strictly-inlineable. */ inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) { if (state->current_compute_dag) { return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable( state->stages[stage_id]->op); } else { return task->compute_dag->access_analyzer.IsStrictlyInlineable(state->stages[stage_id]->op); } } /*! \brief Return whether an op is an output op. */ inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) { if (state->current_compute_dag) { return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput( state->stages[stage_id]->op); } else { return task->compute_dag->access_analyzer.IsOutput(state->stages[stage_id]->op); } } /*! \brief Return whether an op needs multi level tiling. 
*/ inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) { if (state->current_compute_dag) { return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling( state->stages[stage_id]->op); } else { return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(state->stages[stage_id]->op); } } /*! \brief Get all consumers for a stage. This function propagates the relation for inlined ops. */ inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) { std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers; std::set<int> ret; if (state->current_compute_dag) { consumers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers( state, state->stages[stage_id]->op); } else { consumers = task->compute_dag->access_analyzer.GetConsumers(state, state->stages[stage_id]->op); } for (const auto& op : consumers) { ret.insert(OperationToStage(op, state)); } return ret; } /*! \brief Check if a stage has single consumer or all of its consumers share a common root, return * the target consumer root or -1. 
*/ inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) { const std::set<int>& consumers = GetConsumers(task, state, stage_id); if (consumers.empty()) { return -1; } if (consumers.size() == 1) { return *consumers.begin(); } else { // Check all consumers share a common root int common_root_id = -1; bool mismatch = false; for (const auto& consumer_stage_id : consumers) { int root_id = -1; if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kRoot) { root_id = consumer_stage_id; } else if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kIter) { root_id = state->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; } else { LOG(FATAL) << "Invalid case"; } if (common_root_id == -1) { common_root_id = root_id; } else { if (common_root_id != root_id) { mismatch = true; break; } } } return mismatch ? -1 : common_root_id; } } /*! \brief Get all producers for a stage. This function propagates the relation for inlined ops. */ inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) { std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers; std::set<int> ret; if (state->current_compute_dag) { producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers( state, state->stages[stage_id]->op); } else { producers = task->compute_dag->access_analyzer.GetProducers(state, state->stages[stage_id]->op); } for (const auto& op : producers) { ret.insert(OperationToStage(op, state)); } return ret; } /*! \brief Get all producers for a stage. This function DOES NOT propagates the relation for * inlined ops. 
*/ inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) { std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers; std::set<int> ret; if (state->current_compute_dag) { producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers( state->stages[stage_id]->op); } else { producers = task->compute_dag->access_analyzer.GetDirectProducers(state->stages[stage_id]->op); } for (const auto& op : producers) { ret.insert(OperationToStage(op, state)); } return ret; } /*! \brief Get the number of common outer iterators. This function propagates the relation for * chains with multiple ops. */ inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id, int target_stage_id) { if (state->current_compute_dag) { return state->current_compute_dag.as<ComputeDAGNode>() ->access_analyzer.GetNumCommonOuterIterator(state->stages[stage_id]->op, state->stages[target_stage_id]->op); } else { return task->compute_dag->access_analyzer.GetNumCommonOuterIterator( state->stages[stage_id]->op, state->stages[target_stage_id]->op); } } /*! \brief Return whether two ops are elementwise-matched. */ inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id, int target_stage_id) { const auto& op = state->stages[stage_id]->op; const auto& target_op = state->stages[target_stage_id]->op; if (state->current_compute_dag) { return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch( op, target_op); } else { return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op); } } /********** Get informations from Stage/Iterator **********/ /*! \brief Return the extent of an iterator. */ inline int64_t GetExtent(const Iterator& it) { if (it->range.defined()) { if (auto pint = it->range->extent.as<IntImmNode>()) { return pint->value; } } return -1; } /*! 
\brief Compute the product of lengths of all space iters and all reduce iters, respectively. */ inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) { int64_t cum_space_len = 1, cum_reduce_len = 1; for (const auto& iter : stage->iters) { if (iter->iter_kind == IteratorKind::kSpatial) { cum_space_len *= GetExtent(iter); } else if (iter->iter_kind == IteratorKind::kReduction) { cum_reduce_len *= GetExtent(iter); } } return std::make_pair(cum_space_len, cum_reduce_len); } /*! \brief Return whether this stage needs rfactor. */ inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) { const auto& op = state->stages[stage_id]->op; if (op->IsInstance<te::ComputeOpNode>()) { // Compute the product of lengths of all space iters and all reduce iters int cum_space_len, cum_reduce_len; std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(state->stages[stage_id]); if (NeedsMultilevelTiling(task, state, stage_id)) { // Do not use rfactor if we have enough parallelism on space iters if (cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16) { return false; } else { return true; } } else if (cum_reduce_len > 1) { // Always try rfactor for reduction ops return cum_reduce_len > task->hardware_params->num_cores; } } return false; } /*! \brief Return whether the stage has reduce iterators. */ inline bool HasReduceIter(const Stage& stage) { for (const auto& iter : stage->iters) { if (iter->iter_kind != IteratorKind::kSpatial) { return true; } } return false; } /*! \brief Return whether the stage has specific annotated iterators. */ inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) { for (const auto& iter : stage->iters) { if (iter->annotation == type) { return true; } } return false; } /*! \brief Return whether the stage has only one consumer and they are elementwise-matched. 
*/ inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state, int stage_id, int* target_stage_id = nullptr) { // Temporal object to be used if the input pointer is nullptr int temp_target_stage_id; if (target_stage_id == nullptr) { target_stage_id = &temp_target_stage_id; } const std::set<int>& consumers = GetConsumers(task, state, stage_id); if (consumers.size() == 1) { *target_stage_id = *consumers.begin(); if (ElementwiseMatch(task, state, stage_id, *target_stage_id) && (!(HasReduceIter(state->stages[stage_id]) && HasReduceIter(state->stages[*target_stage_id]))) && (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared"))) { return true; } } return false; } /*! \brief Return whether the step changes the number of stages */ inline bool IsStageNumberChangingStep(const Step& step) { return step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>() || step->IsInstance<RfactorStepNode>(); } /*! \brief Return whether the state does cache_read for stage_id. */ inline bool HasCacheReadStage(const State& s, int stage_id) { for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) { if (stage_id == ps->stage_id) { return true; } } if (IsStageNumberChangingStep(s->transform_steps[i])) { if (stage_id > s->transform_steps[i]->stage_id) { stage_id--; } } } return false; } /*! \brief Return whether the state does cache_write for stage_id. */ inline bool HasCacheWriteStage(const State& s, int stage_id) { for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) { if (stage_id == ps->stage_id) { return true; } } if (IsStageNumberChangingStep(s->transform_steps[i])) { if (stage_id > s->transform_steps[i]->stage_id) { stage_id--; } } } return false; } /*! \brief Return whether the state does rfactor for stage_id. 
*/ inline bool HasRfactorStage(const State& s, int stage_id) { for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) { if (stage_id == ps->stage_id) { return true; } } if (IsStageNumberChangingStep(s->transform_steps[i])) { if (stage_id > s->transform_steps[i]->stage_id) { stage_id--; } } } return false; } /*! \brief Return whether the stage does cross thread reduction. */ inline bool HasCrossThreadReduction(const State& state, int stage_id) { std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) { for (const auto& iter : in_stage->iters) { if (iter->annotation == IteratorAnnotation::kThreadX && iter->iter_kind == IteratorKind::kReduction) { return true; } } return false; }; // Check the stage itself if (check_stage(state->stages[stage_id])) { return true; } // Check the attached stages for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) { const auto& res = state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); if (res != state->attach_map->iter_to_attached_stages.end()) { for (int attached_stage_id : res->second) { if (check_stage(state->stages[attached_stage_id])) { return true; } } } } return false; } /*! \brief Return whether the stage has been tiled already. */ inline bool IsTiled(const Stage& stage) { auto op = stage->op.as<te::ComputeOpNode>(); ICHECK(op != nullptr); return stage->iters.size() != op->axis.size() + op->reduce_axis.size(); } /*! \brief Extract primitive iterators from a nested fused or splitted iterator's name. */ inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) { size_t last_pos = 0; for (size_t i = 0; i < name.size(); ++i) { if (name[i] == '@' || name[i] == '.') { // '@' for fuse and '.' 
for split if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') { rets->insert(name.substr(last_pos, i - last_pos)); } last_pos = i + 1; } } if (last_pos < name.size() && !isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') { rets->insert(name.substr(last_pos, name.size() - last_pos)); } } /*! \brief Get the last reduce iterator in the outermost reduce tile. */ inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) { auto pop = stage->op.as<te::ComputeOpNode>(); ICHECK(pop != nullptr); std::set<std::string> original_names; const std::set<std::string>& no_split_at_inner_name_set = stage->op->attrs.count(SearchPolicyKey::no_split_at_inner) ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner) : std::set<std::string>(); size_t reduce_axis_size = 0; for (const auto axis : pop->reduce_axis) { if (!no_split_at_inner_name_set.count(axis->var->name_hint)) { reduce_axis_size++; } } if (reduce_axis_size) { for (const auto& iter : stage->iters) { if (iter->iter_kind == IteratorKind::kReduction) { ExtractOriginalIterators(iter->name, &original_names); if (original_names.size() == reduce_axis_size) { return iter; } } } } else { // Return the first reduce iterator for (const auto& iter : stage->iters) { if (iter->iter_kind == IteratorKind::kReduction) { return iter; } } } LOG(FATAL) << "Cannot find the iterator."; return stage->iters[0]; } /*! \brief Get the target stage id of a history step in the new state. * We need this because the stage_id in the history may be stale due to later steps */ inline int GetTargetStageIDInState(const State& s, int step_id) { int stage_inc = 0; for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) { if (IsStageNumberChangingStep(s->transform_steps[i])) { if (s->transform_steps[i]->stage_id <= s->transform_steps[step_id]->stage_id + stage_inc) stage_inc++; } } return s->transform_steps[step_id]->stage_id + stage_inc; } /*! 
\brief Get all split steps for one stage. */ inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) { for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) { if (auto ps = s->transform_steps[i].as<SplitStepNode>()) { if (stage_id == ps->stage_id) { split_step_ids->push_back(i); } } if (IsStageNumberChangingStep(s->transform_steps[i])) { if (stage_id > s->transform_steps[i]->stage_id) { stage_id--; } } } } /*! \brief Fuse all reduction iterators. */ inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter, Array<Iterator>* space_iters, Array<Iterator>* reduce_iters) { space_iters->clear(); reduce_iters->clear(); for (const auto& iter : state->stages[stage_id]->iters) { if (iter->iter_kind == IteratorKind::kSpatial) { space_iters->push_back(iter); } else if (iter->iter_kind == IteratorKind::kReduction) { reduce_iters->push_back(iter); } } ICHECK(!reduce_iters->empty()); State tmp_s = state; if (reduce_iters->size() > 1) { *fused_iter = tmp_s.fuse(stage_id, *reduce_iters); } else { *fused_iter = (*reduce_iters)[0]; } return tmp_s; } /*! \brief Fuse all outer level space iterators. */ inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) { std::vector<Iterator> to_fuse; for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) { const auto& it = state->stages[stage_id]->iters[iter_id]; // Stop at reduce iterator or annotated iterator if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) { break; } // Stop at compute_at attach point if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) { break; } to_fuse.push_back(it); } State tmp_s = state; if (to_fuse.size() == 1) { *fused_iter = to_fuse[0]; } else { *fused_iter = tmp_s.fuse(stage_id, to_fuse); } return tmp_s; } /*! \brief Random sample states. 
*/ inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen, size_t out_size) { Array<State> out_states; for (size_t i = 0; i < out_size; i++) { out_states.push_back(in_states[(*random_gen)() % in_states.size()]); } return out_states; } /*! \brief Compute prefix-sum probabiilty based on the given weights */ inline void ComputePrefixSumProb(const std::vector<float>& weights, std::vector<double>* prefix_sum_probs) { // Compute selection probabilities. float sum = 0.0; prefix_sum_probs->resize(weights.size()); for (size_t i = 0; i < weights.size(); ++i) { sum += std::max(weights[i], 0.0f); (*prefix_sum_probs)[i] = sum; } for (size_t i = 0; i < weights.size(); ++i) { (*prefix_sum_probs)[i] /= sum; } } /*! \brief Random choose an index according to a prefix sum probability. */ inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) { std::uniform_real_distribution<> dis(0.0, 1.0); double x = dis(*random_gen); ICHECK(!prefix_sum_probs.empty()); return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - prefix_sum_probs.begin(); } /*! \brief Print a title */ inline void PrintTitle(const std::string& title, int verbose) { StdCout(verbose) << Chars('-', 70) << "\n" << Chars('-', 30) << " [ " << title << " ]\n" << Chars('-', 70) << std::endl; } /*! * \brief Enumerate all possible factorization schemes for splitting an axes. * \note This class will memorize the results for reuse. 
*/
class SplitFactorizationMemo {
 public:
  // Cache key for GetFactorizationSchemes; presumably (extent, n_lengths,
  // max_innermost_factor) -- TODO confirm against the implementation.
  using QueryKey = std::tuple<int, int, int>;

  // Returns the factorization schemes for the given query; results are kept in memory_ (the
  // exact constraints on the schemes are defined in the .cc file -- not visible here).
  const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
                                                       int max_innermost_factor);
  // Returns the factors of n; presumably memoized in factor_memory_ -- verify.
  const std::vector<int>& GetFactors(int n);

 private:
  // Recursive helper for the enumeration; presumably uses n_lengths_, tmp_stack_ and results_
  // as scratch state -- confirm in the implementation.
  void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);

  // Memoized schemes per query key.
  std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;

  int n_lengths_;
  Array<Integer> tmp_stack_;
  Array<Array<Integer>>* results_;
  // Memoized factor lists per integer.
  std::unordered_map<int, std::vector<int>> factor_memory_;
};

/*! \brief Get the indices of the SplitSteps that operate on spatial iterators. */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);

/*! \brief Get the possible compute locations for a stage. */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id);

/*!
 * \brief Apply a multi-level tiling structure according to a string format,
 * where "S" stands for a space level and "R" stands for a reduction level.
 * For example, if the format is "SSRSRS", then we will
 * use tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
 * For example, if we apply "SSRSRS" to matrix multiplication,
 * we have space iterators i and j, reduce iterator k.
 * Then the tiling structure is: i0, j0, i1, j1, k0, i2, j2, k1, i3, j3.
 */
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids = nullptr);

/*!
 * \brief Apply tiling structure: space, space, space, ..., with tile sizes from other SplitSteps.
 */
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
                   int n_split);

/*! \brief Prune invalid states and return the results in-place. */
void PruneInvalidState(const SearchTask& task, Array<State>* states);

}  // namespace auto_scheduler
}  // namespace tvm

#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_