/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/ir/transform.cc
 * \brief Infrastructure for transformation passes.
 */
// NOTE(review): the bare #include directives below have lost their targets in
// this copy of the file; restore them from upstream before compiling (the names
// used in this file -- dmlc::ThreadLocalStore, std::stack, PassContext,
// runtime::Registry, ReflectionVTable -- indicate which headers are needed).
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "../runtime/object_internal.h"

namespace tvm {
namespace transform {

using tvm::ReprPrinter;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

/*!
 * \brief Per-thread pass-context state.
 *
 * PassContext::Current() returns the top of `context_stack` when it is
 * non-empty and falls back to `default_context` otherwise.
 * PassContext::EnterWithScope()/ExitWithScope() push/pop `context_stack`.
 */
struct PassContextThreadLocalEntry {
  /*! \brief The default pass context, used when no context has been entered. */
  PassContext default_context;
  /*! \brief The stack of currently entered pass contexts (innermost on top). */
  std::stack context_stack;

  PassContextThreadLocalEntry() { default_context = PassContext(make_object()); }
};

/*! \brief Thread-local store holding each thread's PassContextThreadLocalEntry.
*/ typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { InstrumentEnterPassContext(); PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } void PassContext::ExitWithScope() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); InstrumentExitPassContext(); } PassContext PassContext::Current() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); if (!entry->context_stack.empty()) { return entry->context_stack.top(); } else { return entry->default_context; } } // linearly scan the pass array to match pass_name bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { if (x == pass_name) return true; } return false; } bool PassContext::PassEnabled(const PassInfo& info) const { if (PassArrayContains(operator->()->disabled_pass, info->name)) { return false; } if (PassArrayContains(operator->()->required_pass, info->name)) { return true; } return operator->()->opt_level >= info->opt_level; } class PassConfigManager { public: void Register(std::string key, uint32_t value_type_index) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; info.type_index = value_type_index; info.type_key = runtime::Object::TypeIndex2Key(value_type_index); key2vtype_[key] = info; } // Trying to validate and legalize a config. 
void Legalize(Map* config) { std::vector> update; auto* reflection = ReflectionVTable::Global(); for (auto kv : *config) { auto it = key2vtype_.find(kv.first); if (it == key2vtype_.end()) { std::ostringstream os; os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; int counter = 0; for (const auto& kv : key2vtype_) { os << ' '; if (counter++ != 0) os << ','; os << kv.first; } LOG(FATAL) << os.str(); } const auto& info = it->second; ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; if (kv.second->IsInstance::ContainerType>()) { ObjectRef converted = reflection->CreateObject(info.type_key, Downcast>(kv.second)); update.emplace_back(kv.first, converted); } else { if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " << info.type_key << " but get " << kv.second->GetTypeKey(); } } } for (auto&& kv : update) { config->Set(kv.first, kv.second); } } Map> ListConfigs() { Map> configs; for (const auto& kv : key2vtype_) { Map metadata; metadata.Set("type", kv.second.type_key); configs.Set(kv.first, metadata); } return configs; } static PassConfigManager* Global() { static auto* inst = new PassConfigManager(); return inst; } private: struct ValueTypeInfo { std::string type_key; uint32_t type_index; }; std::unordered_map key2vtype_; }; void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { PassConfigManager::Global()->Register(key, value_type_index); } Map> PassContext::ListConfigs() { return PassConfigManager::Global()->ListConfigs(); } PassContext PassContext::Create() { return PassContext(make_object()); } void PassContext::InstrumentEnterPassContext() { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { Array enter_successes; try { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { pi->EnterPassContext(); 
enter_successes.push_back(pi); } } catch (const Error& e) { LOG(INFO) << "Pass instrumentation entering pass context failed."; LOG(INFO) << "Disable pass instrumentation."; pass_ctx_node->instruments.clear(); for (instrument::PassInstrument pi : enter_successes) { LOG(INFO) << pi->name << " exiting PassContext ..."; pi->ExitPassContext(); LOG(INFO) << pi->name << " exited PassContext."; } enter_successes.clear(); throw e; } } } void PassContext::InstrumentExitPassContext() { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { try { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { pi->ExitPassContext(); } } catch (const Error& e) { LOG(INFO) << "Pass instrumentation exiting pass context failed."; pass_ctx_node->instruments.clear(); throw e; } } } bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { auto pass_ctx_node = this->operator->(); if (!pass_ctx_node->instruments.defined()) { return true; } const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name); bool should_run = true; if (!pass_required) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { should_run &= pi->ShouldRun(ir_module, pass_info); } } if (should_run) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { pi->RunBeforePass(ir_module, pass_info); } } return should_run; } void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { pi->RunAfterPass(ir_module, pass_info); } } } IRModule Pass::operator()(IRModule mod) const { return this->operator()(std::move(mod), PassContext::Current()); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); ICHECK(node != nullptr); const PassInfo& pass_info 
= node->Info(); if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { DLOG(INFO) << "Skipping pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; return mod; } auto ret = node->operator()(std::move(mod), pass_ctx); pass_ctx.InstrumentAfterPass(ret, pass_info); return std::move(ret); } /*! * \brief Module-level passes are designed to implement global * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes * at this level have the full control of a given Relay program including * addition and deletion of functions. */ class ModulePassNode : public PassNode { public: /* \brief The pass meta data.*/ PassInfo pass_info; /*! \brief The pass function sketches the real optimization. For example, * we may need to perform dead code elimination on the module level. We could * implement the algorithm in the `pass_func` and let it run on a module. It * will then remove the dead code including the unused functions in the module. */ runtime::TypedPackedFunc pass_func; ModulePassNode() = default; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a module pass on given pass context. * * \param mod The module that an optimization pass is applied on. * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. */ PassInfo Info() const override { return pass_info; } static constexpr const char* _type_key = "transform.ModulePass"; TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); }; class ModulePass : public Pass { public: ModulePass(runtime::TypedPackedFunc pass_func, PassInfo pass_info); TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; /*! * \brief The SequentialNode contains a set of passes that transform Relay * programs from one AST to another semantically equivalent one. 
* * One example of this level of pass is that the pass manager needs to correctly * perform a host of optimizations with a given optimization level and disabled * passes. */ class SequentialNode : public PassNode { public: /* \brief The pass meta data.*/ PassInfo pass_info; /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); v->Visit("passes", &passes); } /*! * \brief Get the pass information/meta data. */ PassInfo Info() const override { return pass_info; } /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. * * \param mod The module that an optimization pass runs on. * * \return The updated module after resolving pass dependencies. * * TODO(zhiics) Build a dependency graph among the passes using provided * metadata, i.e. required_passes. Likely, we can have a data structure, i.e. * PassInfo, to store the relevant information including the parent passes. */ void ResolveDependency(const IRModule& mod); /*! * \brief Perform optimizations on a series of passes. The aforementioned * typical pass manager jobs could be done by it. This function could * be overloaded to focus on different metrics, i.e. performance, * memory footprint, etc. * * \param mod The module that these passes are applied on. * \param pass_ctx The context that these passes execute on. * * \return Return the updated module. 
*/ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "transform.Sequential"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); data_ = std::move(pass_info); } ModulePass::ModulePass(runtime::TypedPackedFunc pass_func, PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); } // Module -> Module optimizations. IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { DiagnosticContext previous = DiagnosticContext::Default(mod); if (pass_ctx->diag_ctx) { DiagnosticContext tmp = pass_ctx->diag_ctx.value(); pass_ctx->diag_ctx = previous; previous = tmp; } else { pass_ctx->diag_ctx = previous; } ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block this is a bug."; const PassInfo& pass_info = Info(); ICHECK(mod.defined()) << "The input module must be set."; VLOG_CONTEXT << pass_info->name; VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level; VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); mod = pass_func(std::move(mod), pass_ctx); ICHECK(mod.defined()) << "The return value of a module pass must be set."; ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block this is a bug."; pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; VLOG(1) << "Result module:" << std::endl << PrettyPrint(mod); return mod; } Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { auto n = make_object(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); data_ = std::move(n); } Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); 
n->passes = std::move(passes); PassInfo pass_info = PassInfo(0, std::move(name), {}); n->pass_info = std::move(pass_info); data_ = std::move(n); } const SequentialNode* Sequential::operator->() const { return static_cast(get()); } void SequentialNode::ResolveDependency(const IRModule& mod) { // TODO(zhiics) Implement it. // 1. Consider the required passes for each pass. // 2. Only resolve the enabled passes. // 3. Build a dependency graph. Probably we need to update the pass list. LOG(FATAL) << "Pass dependency has not been resolved yet." << "\n"; } Pass GetPass(const String& pass_name) { using tvm::runtime::Registry; const runtime::PackedFunc* f = nullptr; if (pass_name.operator std::string().find("transform.") != std::string::npos) { f = Registry::Get(pass_name); } else if ((f = Registry::Get("transform." + pass_name))) { // pass } else if ((f = Registry::Get("relay._transform." + pass_name))) { } ICHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass"; return (*f)(); } // TODO(zhiics): we currently only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. 
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { for (const Pass& pass : passes) { ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); if (!pass_ctx.PassEnabled(pass_info)) { VLOG(0) << "skipping disabled pass '" << pass_info->name << "'"; continue; } // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); } mod = pass(std::move(mod), pass_ctx); } return mod; } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") .set_body_typed([](int opt_level, String name, tvm::Array required) { return PassInfo(opt_level, name, required); }); TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "The meta data of the pass: "; p->stream << "pass name: " << node->name; p->stream << "opt_level: " << node->opt_level; p->stream << "required passes: [" << "\n"; for (const auto& it : node->required) { p->stream << it << ", "; } p->stream << "]\n"; }); TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed([](runtime::TypedPackedFunc pass_func, PassInfo pass_info) { return ModulePass(pass_func, pass_info); }); TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) { return pass(std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Module 
pass: " << info->name << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(SequentialNode); TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; PassInfo pass_info = PassInfo(opt_level, name, required); *ret = Sequential(passes, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Sequential pass: " << info->name << " at the optimization level " << info->opt_level << ". "; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { const PassInfo pass_info = it->Info(); p->stream << pass_info->name << " "; } p->stream << "]"; }); TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); pctx->instruments = std::move(instruments); if (config.defined()) { pctx->config = config.value(); } PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; p->stream << "\trequired passes: " << node->required_pass << "\n"; p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; p->stream << "\tinstruments: " << node->instruments << "\n"; p->stream << "\tconfig: " << node->config; }); class PassContext::Internal { public: static void EnterScope(PassContext pass_ctx) { 
pass_ctx.EnterWithScope(); } static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); TVM_REGISTER_GLOBAL("transform.OverrideInstruments") .set_body_typed([](PassContext pass_ctx, Array instruments) { pass_ctx.InstrumentExitPassContext(); pass_ctx->instruments = instruments; pass_ctx.InstrumentEnterPassContext(); }); Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); } TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); } // namespace transform } // namespace tvm