/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file auto_scheduler/loop_state.cc * \brief An lightweight IR (intermediate representation) for loop structures. * see auto_scheduler/loop_state.h for more explanation. */ #include #include #include #include #include #include #include "utils.h" namespace tvm { namespace auto_scheduler { TVM_REGISTER_OBJECT_TYPE(StepNode); TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); /********** Iterator **********/ Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation, const std::vector* orig_iters) { auto node = make_object(); node->name = std::move(name); node->range = std::move(range); node->iter_kind = iter_kind; node->annotation = annotation; if (orig_iters != nullptr) { node->orig_iters = *orig_iters; } data_ = std::move(node); } /********** Stage **********/ Stage::Stage(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { node->op_type = StageKind::kCompute; auto* pop = op.as(); for (const auto& axis : pop->axis) { node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, IteratorKind::kSpatial, IteratorAnnotation::kNone)); } for (const auto& axis : pop->reduce_axis) { node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, IteratorKind::kReduction, IteratorAnnotation::kNone)); } } else if (op->IsInstance()) { node->op_type = StageKind::kPlaceholder; } else { LOG(FATAL) << "Unsupported operator type" << op->_type_key; } node->compute_at = ComputeAtKind::kRoot; node->op = std::move(op); node->attrs.auto_unroll_max_step = 0; node->attrs.storage_offset = 0; data_ = std::move(node); } Stage::Stage(te::Operation op, StageKind op_type, const Array& iters, ComputeAtKind compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = iters; node->compute_at = compute_at; node->attrs = attrs; data_ = std::move(node); } /********** AttachMap **********/ void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { AttachMapNode* pnode = CopyOnWrite(); // Delete the current entry of this stage DeleteStageEntry(pnode, stage_id); // Store the new stage/iterator relations to map IterKey iter_key(target_stage_id, target_iter_id); pnode->stage_to_attach_iter[stage_id] = iter_key; pnode->iter_to_attached_stages[iter_key].push_back(stage_id); } void AttachMap::DeleteStage(int stage_id) { AttachMapNode* pnode = CopyOnWrite(); // Delete the original stage entry DeleteStageEntry(pnode, stage_id); } void AttachMap::UpdateIters(const std::vector& original_iters, const std::vector& new_iters) { ICHECK_EQ(original_iters.size(), new_iters.size()); AttachMapNode* pnode = CopyOnWrite(); std::unordered_map> new_iter_to_attached_stages; for (size_t i = 0; i < original_iters.size(); ++i) { auto entry = pnode->iter_to_attached_stages.find(original_iters[i]); // We get > from this map if (entry == pnode->iter_to_attached_stages.end()) { // Skip if this iterator does not have any attach relations continue; } // Update the attaching target of an stage to the new iter in `stage_to_attach_iter` for (const auto& s : entry->second) { pnode->stage_to_attach_iter[s] = new_iters[i]; } // Remove the original iterator relation from `iter_to_attached_stages` and add the new // iterator to it std::vector attached_stages = std::move(entry->second); pnode->iter_to_attached_stages.erase(entry); new_iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); } // Update new entries for (auto& it : new_iter_to_attached_stages) { pnode->iter_to_attached_stages[it.first] = std::move(it.second); } } void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { auto old_entry = pnode->stage_to_attach_iter.find(stage_id); // We get from this map if (old_entry != pnode->stage_to_attach_iter.end()) { // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have // any attached stage, delete this iterm too auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); // We get > from this map FindAndDeleteItem(&entry2->second, stage_id); if (entry2->second.size() == 0) { pnode->iter_to_attached_stages.erase(entry2); } // Delete the stage in `stage_to_attach_iter` pnode->stage_to_attach_iter.erase(old_entry); } } AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const { AttachMap map = AttachMap(make_object()); auto pmap = map.CopyOnWrite(); for (const auto& x : operator->()->stage_to_attach_iter) { auto key = x.first; if (key >= start_id) { key += offset; } auto value = x.second; if (value.first >= start_id) { value.first += offset; } pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); } for (const auto& x : operator->()->iter_to_attached_stages) { auto key = x.first; if (key.first >= start_id) { key.first += offset; } auto value = x.second; for (auto& i : value) { if (i >= start_id) { i += offset; } } pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); } return map; } /********** State **********/ State::State(const Array& ops) { auto node = make_object(); for (const auto& op : ops) { node->stages.push_back(Stage(op)); } node->attach_map = AttachMap(make_object()); node->concrete = true; data_ = std::move(node); } /********** Schedule primitives apis for state **********/ Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " << "kThreadX, kThreadY, kBlockZ, kThreadZ"; } AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; // Don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { if (auto imm = it->range->extent.as()) { if (imm->value > max_unroll) { return it; } } } AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } Iterator State::fuse(int stage_id, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; Array indices; GetIndices(stage->iters, iters, &indices); FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; ICHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " << "should be specified"; Array after_ids; GetIndices(stage->iters, order, &after_ids); ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); step->ApplyToState(this); } Array State::split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStep(stage_id, GetIndex(stage->iters, it), it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; FollowSplitStep step = FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } Array State::follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { const Stage& stage = operator->()->stages[stage_id]; StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); step->ApplyToState(this); } void State::compute_inline(int stage_id) { ComputeInlineStep step = ComputeInlineStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); step->ApplyToState(this); } void State::compute_root(int stage_id) { ComputeRootStep step = ComputeRootStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); step->ApplyToState(this); } int State::cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, const ComputeDAG& dag) { CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this, dag); } int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) { CacheWriteStep step = CacheWriteStep(stage_id, scope_name); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this, dag); } int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) { const Stage& stage = operator->()->stages[stage_id]; RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this, dag); } // Print stage to ostream void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; if (stage->attrs.auto_unroll_max_step != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; } if (stage->attrs.storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n"; } size_t indent = 0; for (size_t i = 0; i < stage->iters.size(); ++i) { const Iterator& iter = stage->iters[i]; if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) { for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } *os << IteratorAnnotationString[static_cast(iter->annotation)] << " "; if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")"; } else { *os << iter->name << " (None)"; } *os << "\n"; indent += 2; } if (state.defined()) { IterKey iter_key(stage_id, i); auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); if (pair != state->attach_map->iter_to_attached_stages.end()) { // Print the attached stage for (const auto& attach_stage_id : pair->second) { PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop); } } } } for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } *os << stage->op->name << " = ...\n"; } // Print state to ostream void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) { // Gather placeholders Array placeholders; for (const auto& stage : state->stages) { if (stage->op_type == StageKind::kPlaceholder) { placeholders.push_back(stage->op->name); } } *os << "Placeholder: "; for (size_t i = 0; i < placeholders.size(); ++i) { *os << placeholders[i]; if (i != placeholders.size() - 1) { *os << ", "; } } *os << "\n"; // Print all stages for (size_t i = 0; i < state->stages.size(); ++i) { const Stage& stage = state->stages[i]; if (stage->op_type == StageKind::kPlaceholder) { continue; } else if (stage->op_type == StageKind::kCompute) { if (stage->compute_at == ComputeAtKind::kRoot) { PrintStage(os, i, state, 0, delete_trivial_loop); } } else { LOG(FATAL) << "Invalid op type"; } } } String State::ToStr(bool delete_trivial_loop) const { std::ostringstream os; PrintState(&os, (*this), delete_trivial_loop); return os.str(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { const auto& stage = tvm::Downcast(ref); p->stream << stage->GetTypeKey() << "(" << stage.get() << ": " << stage->op->name << ")"; }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { PrintState(&p->stream, tvm::Downcast(ref), true); }); /********** State interface API for ffi **********/ TVM_REGISTER_GLOBAL("auto_scheduler.StateBind") .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) { const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type)); return Array{state, res}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel") .set_body_typed([](State state, int stage_id, const Iterator& it) { const auto& res = state.parallel(stage_id, it); return Array{state, res}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll") .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) { const auto& res = state.unroll(stage_id, it, max_unroll); return Array{state, res}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") .set_body_typed([](State state, int stage_id, const Iterator& it) { const auto& res = state.vectorize(stage_id, it); return Array{state, res}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { const auto& res = state.fuse(stage_id, iters); return Array{state, res}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma") .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) { state.pragma(stage_id, it, pragma_type); return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") .set_body_typed([](State state, int stage_id, const Array& order) { state.reorder(stage_id, order); return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer) { const auto& res = state.split(stage_id, it, lengths, inner_to_outer); return Array{state, res}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, int n_split) { const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts) { const auto& res = state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts); return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign") .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) { state.storage_align(stage_id, it, factor, offset); return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, const Iterator& target_iter) { state.compute_at(stage_id, target_stage_id, target_iter); return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline") .set_body_typed([](State state, int stage_id) { state.compute_inline(stage_id); return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") .set_body_typed([](State state, int stage_id) { state.compute_root(stage_id); return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead") .set_body_typed([](State state, int stage_id, const String& scope_name, const Array& reader_stage_ids, const ComputeDAG& dag) { int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag); return Array{state, Integer(res)}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite") .set_body_typed([](State state, int stage_id, const String& scope_name, const ComputeDAG& task_dag) { int res = state.cache_write(stage_id, scope_name, task_dag); return Array{state, Integer(res)}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor") .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) { int res = state.rfactor(stage_id, it, factor_iter_id, dag); return Array{state, Integer(res)}; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); } // namespace auto_scheduler } // namespace tvm