/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "./utils.h"

namespace tvm {
namespace tir {

/**************** Constructors ****************/

/*! \brief Construct an empty trace backed by a fresh TraceNode. */
Trace::Trace() { data_ = make_object<TraceNode>(); }

/*!
 * \brief Construct a trace from a list of instructions and their decisions.
 * \param insts The instructions of the trace, taken by value and moved in.
 * \param decisions The decisions keyed by instruction, taken by value and moved in.
 */
Trace::Trace(Array<Instruction> insts, Map<Instruction, ObjectRef> decisions) {
  ObjectPtr<TraceNode> n = make_object<TraceNode>();
  n->insts = std::move(insts);
  n->decisions = std::move(decisions);
  data_ = std::move(n);
}

/**************** Utilities ****************/

/*!
 * \brief Check whether an instruction kind is the "EnterPostproc" kind.
 * \param inst_kind The instruction kind to test.
 * \return True if `inst_kind` is the same object as the registered "EnterPostproc" kind.
 */
bool IsPostproc(const InstructionKind& inst_kind) {
  // Looked up once and cached for subsequent calls.
  static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc");
  return inst_kind.same_as(inst_enter_postproc);
}

/*!
 * \brief Count the leading instructions that should be kept.
 * \param insts The instructions to scan.
 * \param remove_postproc If false, every instruction is counted; if true, counting stops at the
 * first "EnterPostproc" instruction.
 * \return The number of instructions to keep, counted from the beginning of `insts`.
 */
int GetNumValidInstructions(const Array<Instruction>& insts, bool remove_postproc) {
  if (!remove_postproc) {
    return insts.size();
  }
  int n_insts = 0;
  for (const Instruction& inst : insts) {
    if (!IsPostproc(inst->kind)) {
      ++n_insts;
    } else {
      break;
    }
  }
  return n_insts;
}

/**************** TranslateInputRVs ****************/

/*!
 * \brief Translate the inputs of an instruction through a mapping of random variables.
 * \param inputs The inputs to translate. Null, string, integer and float inputs are passed
 * through unchanged.
 * \param rv_map The mapping from old random variables to new ones.
 * \return The translated inputs, in the same order as `inputs`.
 * \note Fails if a BlockRV/LoopRV/Var input is missing from `rv_map`, or if an input has an
 * unsupported type. Variables inside a PrimExpr that are absent from `rv_map` are left as is.
 */
Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
                                   const std::unordered_map<const Object*, const Object*>& rv_map) {
  Array<ObjectRef> result;
  result.reserve(inputs.size());
  for (const ObjectRef& input : inputs) {
    if (!input.defined() ||                   // constant: nullptr
        input->IsInstance<StringObj>() ||     // constant: string
        input->IsInstance<IntImmNode>() ||    // constant: integer
        input->IsInstance<FloatImmNode>()) {  // constant: float
      result.push_back(input);
    } else if (input->IsInstance<BlockRVNode>() ||  // RV: block
               input->IsInstance<LoopRVNode>() ||   // RV: loop
               input->IsInstance<VarNode>()) {      // RV: var
      auto it = rv_map.find(input.get());
      ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input;
      result.push_back(GetRef<ObjectRef>(it->second));
    } else if (const auto* expr = input.as<PrimExprNode>()) {  // RV: Expr
      result.push_back(
          Substitute(GetRef<PrimExpr>(expr), [&rv_map](const Var& var) -> Optional<PrimExpr> {
            auto it = rv_map.find(var.get());
            if (it == rv_map.end()) {
              return NullOpt;
            }
            const Object* dst = it->second;
            ICHECK(dst->IsInstance<VarNode>())
                << "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey();
            return GetRef<Var>(static_cast<const VarNode*>(dst));
          }));
    } else {
      ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: "
                    << input->GetTypeKey();
      throw;
    }
  }
  return result;
}

/*!
 * \brief Translate the inputs of an instruction into their serializable form.
 * \param inputs The inputs to translate.
 * \param rv_names The names already assigned to random variables.
 * \return One entry per input: the String "None" for a null input, the assigned name for a
 * known random variable, the content wrapped in double quotes for a string, and the input
 * itself for an integer or float.
 * \note Fails with an IndexError message for a BlockRV/LoopRV/Var that has no name in
 * `rv_names`, and with a TypeError message for any other unsupported type.
 */
Array<ObjectRef> TranslateInputRVs(
    const Array<ObjectRef>& inputs,
    const std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual>& rv_names) {
  Array<ObjectRef> results;
  results.reserve(inputs.size());
  for (const ObjectRef& input : inputs) {
    if (!input.defined()) {
      // Case 0. nullptr => None
      results.push_back(String("None"));
      continue;
    }
    auto it = rv_names.find(input);
    if (it != rv_names.end()) {
      // Case 1. BlockRV, LoopRV, VarRV
      results.push_back(it->second);
    } else if (const auto* str_obj = input.as<StringObj>()) {
      // Case 2. string => "content"
      // Use the explicit size so that the whole payload is copied, not just up to the first NUL.
      results.push_back(String('"' + std::string(str_obj->data, str_obj->size) + '"'));
    } else if (input->IsInstance<IntImmNode>() || input->IsInstance<FloatImmNode>()) {
      // Case 3. integer or floating-point number
      results.push_back(input);
    } else if (input->IsInstance<BlockRVNode>() || input->IsInstance<LoopRVNode>() ||
               input->IsInstance<VarNode>()) {
      // The type tests must be on the element `input`, not on the array `inputs`.
      LOG(FATAL) << "IndexError: Random variable is not defined " << input;
      throw;
    } else {
      LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input->GetTypeKey();
      throw;
    }
  }
  return results;
}

/*!
 * \brief Translate serialized inputs of an instruction back into objects.
 * \param inputs The serialized inputs: integers, floats, or strings.
 * \param named_rvs The mapping from names to the objects they denote.
 * \return One entry per input: the input itself for an integer or float, the unquoted content
 * for a string wrapped in double quotes, and the object found in `named_rvs` otherwise.
 * \note Fails if an input is null, is neither a number nor a string, is an empty string, or
 * names something that is absent from `named_rvs`.
 */
Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
                                   const std::unordered_map<std::string, ObjectRef>& named_rvs) {
  Array<ObjectRef> results;
  results.reserve(inputs.size());
  for (const ObjectRef& input : inputs) {
    // A null entry cannot be inspected with `->`, so reject it explicitly.
    CHECK(input.defined()) << "TypeError: Expect String, but gets: None";
    // Case 3. integer or floating-point number
    if (input->IsInstance<IntImmNode>() || input->IsInstance<FloatImmNode>()) {
      results.push_back(input);
      continue;
    }
    const auto* str = input.as<StringObj>();
    CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey();
    CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names";
    const char* name = str->data;
    int64_t size = str->size;
    // Case 2. string
    // `>= 2` so that the quoted empty string `""` is recognized as well.
    if (size >= 2 && name[0] == '"' && name[size - 1] == '"') {
      results.push_back(String(std::string(name + 1, size - 2)));
      continue;
    }
    // Case 0 & 1. None, BlockRV, LoopRV, VarRV
    const std::string key(name, size);
    auto it = named_rvs.find(key);
    CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << key;
    results.push_back(it->second);
  }
  return results;
}

/**************** TranslateAddOutputRVs ****************/

/*!
 * \brief Record the correspondence between old outputs and new outputs of an instruction.
 * \param old_outputs The outputs recorded in the instruction.
 * \param new_outputs The outputs to map them to; must have the same size as `old_outputs`.
 * \param rv_map The mapping to update; an existing entry for an old output is overwritten.
 */
void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs,
                           std::unordered_map<const Object*, const Object*>* rv_map) {
  ICHECK_EQ(old_outputs.size(), new_outputs.size());
  int n = old_outputs.size();
  const ObjectRef* p_old = old_outputs.GetArrayNode()->begin();
  const ObjectRef* p_new = new_outputs.GetArrayNode()->begin();
  for (int i = 0; i < n; ++i) {
    (*rv_map)[p_old[i].get()] = p_new[i].get();
  }
}

/*!
 * \brief Assign names to the outputs of an instruction.
 * \param outputs The outputs to name; each must not have a name in `rv_names` yet.
 * \param rv_names The name table to extend with the new names.
 * \return The new names, in the same order as `outputs`. A name is the prefix "b" (BlockRV),
 * "l" (LoopRV) or "v" (Var) followed by the size of `rv_names` before the insertion.
 */
Array<String> TranslateAddOutputRVs(
    const Array<ObjectRef>& outputs,
    std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual>* rv_names) {
  Array<String> results;
  results.reserve(outputs.size());
  for (const ObjectRef& output : outputs) {
    int i = rv_names->size();
    ICHECK(!rv_names->count(output))
        << "ValueError: The random variable has been produced once: " << rv_names->at(output);
    String result{ObjectPtr<StringObj>{nullptr}};
    if (output->IsInstance<BlockRVNode>()) {
      result = "b" + std::to_string(i);
    } else if (output->IsInstance<LoopRVNode>()) {
      result = "l" + std::to_string(i);
    } else if (output->IsInstance<VarNode>()) {
      result = "v" + std::to_string(i);
    } else {
      LOG(FATAL) << "TypeError: Cannot recognize the type of the random variable: "
                 << output->GetTypeKey();
      throw;
    }
    results.push_back(result);
    rv_names->emplace(output, std::move(result));
  }
  return results;
}

/*!
 * \brief Bind the serialized output names of an instruction to newly produced objects.
 * \param old_outputs The output names recorded in the serialized instruction.
 * \param new_outputs The produced objects; must have the same size as `old_outputs`.
 * \param named_rvs The name table to extend; `emplace` keeps an entry that already exists.
 */
void TranslateAddOutputRVs(const Array<String>& old_outputs, const Array<ObjectRef>& new_outputs,
                           std::unordered_map<std::string, ObjectRef>* named_rvs) {
  ICHECK_EQ(old_outputs.size(), new_outputs.size());
  int n = old_outputs.size();
  const ObjectRef* p_old = old_outputs.GetArrayNode()->begin();
  const ObjectRef* p_new = new_outputs.GetArrayNode()->begin();
  for (int i = 0; i < n; ++i) {
    const auto* name = static_cast<const StringObj*>(p_old[i].get());
    named_rvs->emplace(std::string(name->data, name->size), p_new[i]);
  }
}

/**************** Add/Remove/Get ****************/

/*!
 * \brief Look up the decision made on an instruction.
 * \param inst The instruction to look up.
 * \return The recorded decision, or NullOpt if `inst` has none.
 */
Optional<ObjectRef> TraceNode::GetDecision(const Instruction& inst) const {
  auto it = this->decisions.find(inst);
  return it == this->decisions.end() ? Optional<ObjectRef>(NullOpt) : (*it).second;
}

/*! \brief Append an instruction without a decision to the end of the trace. */
void TraceNode::Append(Instruction inst) { insts.push_back(std::move(inst)); }

/*! \brief Append an instruction to the end of the trace and record its decision. */
void TraceNode::Append(Instruction inst, ObjectRef decision) {
  // Record the decision first: `inst` is moved from by the push_back below.
  decisions.Set(inst, std::move(decision));
  insts.push_back(std::move(inst));
}

/*!
 * \brief Remove the last instruction of the trace together with its decision, if any.
 * \return The removed instruction, or NullOpt if the trace is empty.
 */
Optional<Instruction> TraceNode::Pop() {
  if (insts.empty()) {
    return NullOpt;
  }
  Instruction inst = insts.back();
  insts.pop_back();
  if (decisions.count(inst)) {
    decisions.erase(inst);
  }
  return inst;
}

/**************** Interfacing with InstructionKind ****************/

/*!
 * \brief Apply the instructions of the trace to a schedule, in order.
 * \param sch The schedule to apply the instructions to.
 * \param remove_postproc If true, stop at the first "EnterPostproc" instruction.
 * \param decision_provider If not nullptr, it is called for every instruction with the
 * translated inputs, the attrs and the recorded decision, and its result is used as the decision.
 */
void TraceNode::ApplyToSchedule(
    Schedule sch, bool remove_postproc,
    runtime::TypedPackedFunc<ObjectRef(const Instruction& inst, const Array<ObjectRef>& inputs,  //
                                       const Array<ObjectRef>& attrs,                             //
                                       const Optional<ObjectRef>& decision)>
        decision_provider) const {
  // Maps the random variables recorded in the trace to the ones produced on `sch`.
  std::unordered_map<const Object*, const Object*> rv_map;
  for (const Instruction& inst : this->insts) {
    if (remove_postproc && IsPostproc(inst->kind)) {
      break;
    }
    Array<ObjectRef> inputs = TranslateInputRVs(inst->inputs, rv_map);
    Array<ObjectRef> attrs = inst->attrs;
    Optional<ObjectRef> decision = this->GetDecision(inst);
    if (decision_provider != nullptr) {
      decision = decision_provider(inst, inputs, attrs, decision);
    }
    Array<ObjectRef> outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision);
    TranslateAddOutputRVs(inst->outputs, outputs, &rv_map);
  }
}

/*!
 * \brief Serialize the trace into a nested array structure.
 * \param remove_postproc If true, stop at the first "EnterPostproc" instruction.
 * \return A two-element array: the serialized instructions, each as
 * [inst name, inputs, attrs, outputs], and the decisions, each as [instruction index, decision].
 */
ObjectRef TraceNode::AsJSON(bool remove_postproc) const {
  std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual> rv_names;
  Array<ObjectRef> json_insts;
  Array<ObjectRef> json_decisions;
  json_insts.reserve(this->insts.size());
  json_decisions.reserve(this->insts.size());
  int i = 0;
  for (const Instruction& inst : this->insts) {
    const InstructionKind& kind = inst->kind;
    if (remove_postproc && IsPostproc(kind)) {
      break;
    }
    json_insts.push_back(Array<ObjectRef>{
        /* 0: inst name */ kind->name,
        /* 1: inputs    */ TranslateInputRVs(inst->inputs, rv_names),
        /* 2: attrs     */ kind->f_attrs_as_json != nullptr ? kind->f_attrs_as_json(inst->attrs)
                                                            : ObjectRef(inst->attrs),
        /* 3: outputs   */ TranslateAddOutputRVs(inst->outputs, &rv_names),
    });
    if (Optional<ObjectRef> decision = this->GetDecision(inst)) {
      json_decisions.push_back(Array<ObjectRef>{
          /* 0: index    */ Integer(i),
          /* 1: decision */ decision.value(),
      });
    }
    ++i;
  }
  return Array<ObjectRef>{
      /* 0: trace    */ std::move(json_insts),
      /* 1: decision */ std::move(json_decisions),
  };
}

/*!
 * \brief Render the trace as a sequence of python API calls.
 * \param remove_postproc If true, stop at the first "EnterPostproc" instruction.
 * \return One string per instruction, as produced by the instruction kind's `f_as_python`.
 */
Array<String> TraceNode::AsPython(bool remove_postproc) const {
  std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual> rv_names;
  Array<String> py_trace;
  py_trace.reserve(this->insts.size());
  for (const Instruction& inst : this->insts) {
    if (remove_postproc && IsPostproc(inst->kind)) {
      break;
    }
    Array<ObjectRef> attrs;
    attrs.reserve(inst->attrs.size());
    for (const ObjectRef& obj : inst->attrs) {
      if (const auto* str = obj.as<StringObj>()) {
        // String attrs are wrapped in double quotes; copy the payload with its explicit size.
        attrs.push_back(String('"' + std::string(str->data, str->size) + '"'));
      } else {
        attrs.push_back(obj);
      }
    }
    py_trace.push_back(
        inst->kind->f_as_python(/*inputs=*/TranslateInputRVs(inst->inputs, rv_names),
                                /*attrs=*/attrs,
                                /*decision=*/this->GetDecision(inst),
                                /*outputs=*/TranslateAddOutputRVs(inst->outputs, &rv_names)));
  }
  return py_trace;
}

/*!
 * \brief Apply a serialized trace to a schedule.
 * \param json The serialized trace: a two-element array of instructions and decisions, in the
 * layout produced by TraceNode::AsJSON.
 * \param sch The schedule to apply the instructions to.
 * \note Fails with a ValueError message if `json` or one of its entries is malformed, or if a
 * decision refers to an instruction index that is out of range.
 */
void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) {
  Array<ObjectRef> json_insts{nullptr};
  Array<ObjectRef> json_decisions{nullptr};
  // Parse `json` into `json_insts` and `json_decisions`
  try {
    const ArrayNode* arr = json.as<ArrayNode>();
    ICHECK(arr && arr->size() == 2);
    const auto* arr0 = arr->at(0).as<ArrayNode>();
    const auto* arr1 = arr->at(1).as<ArrayNode>();
    ICHECK(arr0 && arr1);
    json_insts = GetRef<Array<ObjectRef>>(arr0);
    json_decisions = GetRef<Array<ObjectRef>>(arr1);
  } catch (const tvm::Error& e) {
    LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of "
                  "instructions and an array of decisions, but gets: "
               << json;
    throw;
  }
  // Parse `json_decisions`
  std::vector<Optional<ObjectRef>> decisions(json_insts.size(), NullOpt);
  for (const ObjectRef& decision_entry : json_decisions) {
    // Kept as int64_t so that the value of the IntImm is not narrowed before validation.
    int64_t index = -1;
    ObjectRef decision{nullptr};
    try {
      const ArrayNode* arr = decision_entry.as<ArrayNode>();
      ICHECK(arr && arr->size() == 2);
      const IntImmNode* arr0 = arr->at(0).as<IntImmNode>();
      ICHECK(arr0);
      index = arr0->value;
      decision = arr->at(1);
    } catch (const tvm::Error& e) {
      LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, "
                    "decision], but gets: "
                 << decision_entry;
      throw;
    }
    // The index comes from external data; reject it instead of writing out of bounds.
    CHECK(index >= 0 && index < static_cast<int64_t>(decisions.size()))
        << "ValueError: The index of a json decision is out of range: " << index
        << ", while the number of instructions is " << decisions.size();
    decisions[index] = std::move(decision);
  }
  // Parse `json_insts`
  std::unordered_map<std::string, ObjectRef> named_rvs{{"None", ObjectRef{nullptr}}};
  int i = 0;
  for (const ObjectRef& inst_entry : json_insts) {
    InstructionKind kind{nullptr};
    Array<ObjectRef> inputs{nullptr};
    Array<ObjectRef> attrs{nullptr};
    Array<String> outputs{ObjectPtr<Object>{nullptr}};
    // Parse the entry
    try {
      const auto* arr = inst_entry.as<ArrayNode>();
      ICHECK(arr && arr->size() == 4);
      const auto* arr0 = arr->at(0).as<StringObj>();
      const auto* arr1 = arr->at(1).as<ArrayNode>();
      const auto* arr2 = arr->at(2).as<ArrayNode>();
      const auto* arr3 = arr->at(3).as<ArrayNode>();
      ICHECK(arr0 && arr1 && arr2 && arr3);
      for (const ObjectRef& str : *arr3) {
        // A null entry must be rejected before it is inspected with `->`.
        ICHECK(str.defined() && str->IsInstance<StringObj>());
      }
      kind = InstructionKind::Get(arr0->data);
      inputs = GetRef<Array<ObjectRef>>(arr1);
      attrs = GetRef<Array<ObjectRef>>(arr2);
      outputs = GetRef<Array<String>>(arr3);
    } catch (const tvm::Error& e) {
      LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, "
                    "inputs, attrs, outputs], but gets: "
                 << inst_entry;
      throw;
    }
    // Parse inputs
    inputs = TranslateInputRVs(inputs, named_rvs);
    // Parse attrs
    if (kind->f_attrs_from_json != nullptr) {
      attrs = kind->f_attrs_from_json(attrs);
    }
    // Apply to the schedule
    Array<ObjectRef> new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]);
    // Parse outputs
    TranslateAddOutputRVs(outputs, new_outputs, &named_rvs);
    ++i;
  }
}

/**************** Creation ****************/

/*!
 * \brief Create a new trace with a decision set on one instruction.
 * \param inst The instruction whose decision is set.
 * \param decision The decision to record for `inst`.
 * \param remove_postproc If true, the instructions from the first "EnterPostproc" on are dropped.
 * \return The new trace; this trace is left unchanged.
 */
Trace TraceNode::WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const {
  int n_insts = GetNumValidInstructions(this->insts, remove_postproc);
  Array<Instruction> new_insts =
      Array<Instruction>{this->insts.begin(), this->insts.begin() + n_insts};
  Map<Instruction, ObjectRef> new_decisions{this->decisions.begin(), this->decisions.end()};
  new_decisions.Set(std::move(inst), std::move(decision));
  return Trace(new_insts, new_decisions);
}

/*!
 * \brief Create a new trace without the pure instructions whose outputs are never used.
 * \param remove_postproc If true, the instructions from the first "EnterPostproc" on are dropped.
 * \return The simplified trace, with the kept instructions in their original order and the
 * decisions of the kept instructions.
 */
Trace TraceNode::Simplified(bool remove_postproc) const {
  int n_insts = GetNumValidInstructions(this->insts, remove_postproc);
  std::unordered_set<const Object*> used_rvs;
  std::vector<Instruction> new_insts;
  std::unordered_map<Instruction, ObjectRef, ObjectPtrHash, ObjectPtrEqual> new_decisions;
  new_insts.reserve(n_insts);
  new_decisions.reserve(this->decisions.size());
  // Walk backwards so that the uses of a random variable are seen before its definition.
  for (int inst_idx = n_insts - 1; inst_idx >= 0; --inst_idx) {
    const Instruction& inst = this->insts[inst_idx];
    // Check if all the variables the instruction defined are dead
    // If so, and the instruction is pure, we can safely remove this instruction
    bool all_defs_dead = inst->kind->is_pure;
    if (all_defs_dead) {
      for (const ObjectRef& obj : inst->outputs) {
        if (used_rvs.count(obj.get())) {
          all_defs_dead = false;
          break;
        }
      }
    }
    // Remove this instruction
    if (all_defs_dead) {
      continue;
    }
    // Otherwise this instruction is not dead
    new_insts.push_back(inst);
    if (Optional<ObjectRef> decision = this->GetDecision(inst)) {
      new_decisions.emplace(inst, std::move(decision));
    }
    // Add its inputs as "used" ones
    for (const ObjectRef& obj : inst->inputs) {
      // A null input is a constant: it uses no random variable and must not be dereferenced.
      if (!obj.defined()) {
        continue;
      }
      if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
          obj->IsInstance<VarNode>()) {
        used_rvs.insert(obj.get());
        continue;
      } else if (obj->IsInstance<PrimExprNode>()) {
        PostOrderVisit(obj, [&used_rvs](const ObjectRef& node) -> void {
          if (node->IsInstance<VarNode>()) {
            used_rvs.insert(node.get());
          }
        });
      }
    }
  }
  // `new_insts` was filled back to front, hence the reverse iterators.
  return Trace(Array<Instruction>(new_insts.rbegin(), new_insts.rend()),
               Map<Instruction, ObjectRef>(new_decisions));
}

/**************** Repr ****************/

// Prints a trace as its python form, one instruction per line, without a trailing newline.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TraceNode>([](const ObjectRef& obj, ReprPrinter* p) {
      const auto* self = obj.as<TraceNode>();
      ICHECK_NOTNULL(self);
      Array<String> repr = self->AsPython(/*remove_postproc=*/false);
      bool is_first = true;
      for (const String& line : repr) {
        if (is_first) {
          is_first = false;
        } else {
          // A plain newline: there is no need to flush the stream after every line.
          p->stream << '\n';
        }
        p->stream << line;
      }
    });

/**************** Instruction Registration ****************/

/*! \brief Traits of the "EnterPostproc" instruction: no inputs, attrs or decisions; not pure. */
struct EnterPostprocTraits : public UnpackedInstTraits<EnterPostprocTraits> {
  static constexpr const char* kName = "EnterPostproc";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 0;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction by calling `EnterPostproc` on the schedule. */
  static void UnpackedApplyToSchedule(Schedule sch) { return sch->EnterPostproc(); }

  /*! \brief Render the instruction as a python call to `enter_postproc`. */
  static String UnpackedAsPython(Array<String> outputs) {
    PythonAPICall py("enter_postproc");
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits);

/**************** FFI ****************/

TVM_REGISTER_NODE_TYPE(TraceNode);
TVM_REGISTER_GLOBAL("tir.schedule.Trace")
    .set_body_typed([](Optional<Array<Instruction>> insts,
                       Optional<Map<Instruction, ObjectRef>> decisions) {
      // A missing argument is treated as an empty container.
      return Trace(insts.value_or(Array<Instruction>()),
                   decisions.value_or(Map<Instruction, ObjectRef>()));
    });
TVM_REGISTER_GLOBAL("tir.schedule.TraceGetDecision")
    .set_body_method<Trace>(&TraceNode::GetDecision);
TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend")
    .set_body_typed([](Trace self, Instruction inst, Optional<ObjectRef> decision) {
      // Dispatch to the overload that matches whether a decision was given.
      if (decision.defined()) {
        return self->Append(inst, decision.value());
      } else {
        return self->Append(inst);
      }
    });
TVM_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method<Trace>(&TraceNode::Pop);
TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule")
    .set_body_method<Trace>(&TraceNode::ApplyToSchedule);
TVM_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method<Trace>(&TraceNode::AsJSON);
TVM_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method<Trace>(&TraceNode::AsPython);
TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision")
    .set_body_method<Trace>(&TraceNode::WithDecision);
TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method<Trace>(&TraceNode::Simplified);
TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule")
    .set_body_typed(Trace::ApplyJSONToSchedule);

}  // namespace tir
}  // namespace tvm