/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file llvm_module.cc
 * \brief LLVM runtime module for TVM
 */
#ifdef TVM_LLVM_VERSION

#include <tvm/ir/module.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>

#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include "../../runtime/file_utils.h"
#include "../../runtime/library_module.h"
#include "../func_registry_generator.h"
#include "codegen_blob.h"
#include "codegen_cpu.h"
#include "codegen_llvm.h"
#include "llvm_common.h"

namespace tvm {
namespace codegen {

using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
 * \brief Runtime module backed by an llvm::Module.
 *
 * The module can be saved to a file (object, assembly, LLVM IR or bitcode),
 * printed as source, or executed through an llvm::ExecutionEngine that is
 * created lazily on the first function lookup.
 */
class LLVMModuleNode final : public runtime::ModuleNode {
 public:
  /*! \brief Release the module and, when the JIT was created, run static destructors. */
  ~LLVMModuleNode() {
    module_.reset();
    if (ee_ != nullptr) {
      ee_->runStaticConstructorsDestructors(true);
      delete ee_;
    }
  }

  /*! \return The type key of this module, "llvm". */
  const char* type_key() const final { return "llvm"; }

  /*!
   * \brief Look up a function by name.
   *
   * A handful of reserved names are answered without touching the JIT; every
   * other name triggers lazy JIT creation and a symbol lookup.
   *
   * \param name The function name to look up.
   * \param sptr_to_self Owning pointer to this module, captured by the returned closures.
   * \return The function, or a null PackedFunc when the symbol does not exist.
   */
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    if (name == "__tvm_is_system_module") {
      bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr);
      return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; });
    } else if (name == "get_func_names") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; });
    } else if (name == "get_symbol") {
      return PackedFunc(nullptr);
    } else if (name == "get_const_vars") {
      return PackedFunc(nullptr);
    } else if (name == "_get_target_triple") {
      std::ostringstream target_triple_ss;
      target_triple_ss << tm_->getTargetTriple().str();
      // getTargetTriple() doesn't include other flags besides the triple. Add back flags which are
      // important for ModulePackImportsToLLVM.
      if (tm_->Options.FloatABIType == llvm::FloatABI::ABIType::Soft) {
        target_triple_ss << " -mfloat-abi=soft";
      }
      std::string mabi = tm_->Options.MCOptions.ABIName;
      if (!mabi.empty()) {
        target_triple_ss << " -mabi=" << mabi;
      }
      llvm::StringRef mcpu = tm_->getTargetCPU();
      if (!mcpu.empty() && mcpu != "generic") {
        target_triple_ss << " -mcpu=" << mcpu.str();
      }
      std::string target_triple = target_triple_ss.str();
      return PackedFunc([target_triple](TVMArgs args, TVMRetValue* rv) { *rv = target_triple; });
    }
    // NOTE(review): ee_ is read here before mutex_ is taken; LazyInitJIT() re-checks it
    // under the lock, so at most one engine is ever created.
    if (ee_ == nullptr) LazyInitJIT();

    std::lock_guard<std::mutex> lock(mutex_);

    TVMBackendPackedCFunc faddr;
    if (name == runtime::symbol::tvm_module_main) {
      // The tvm_module_main global is read as a C string naming the real entry function.
      const char* entry_name =
          reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
      ICHECK(entry_name != nullptr)
          << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
      faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(entry_name));
    } else {
      faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(name));
    }
    if (faddr == nullptr) return PackedFunc();
    return WrapPackedFunc(faddr, sptr_to_self);
  }

  /*!
   * \brief Write the module to a file.
   *
   * Supported formats: "o"/"obj" (object file), "s"/"asm" (assembly),
   * "ll" (textual LLVM IR) and "bc" (LLVM bitcode). Any other format is fatal.
   *
   * \param file_name Destination path.
   * \param format Requested format, resolved through runtime::GetFileFormat.
   */
  void SaveToFile(const std::string& file_name, const std::string& format) final {
    std::string fmt = runtime::GetFileFormat(file_name, format);
    std::error_code ecode;
#if TVM_LLVM_VERSION <= 70
    llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
#else
    llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None);
#endif
    ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message();
    if (fmt == "o" || fmt == "obj") {
      EmitMachineCode(dest, /*emit_object=*/true);
    } else if (fmt == "s" || fmt == "asm") {
      EmitMachineCode(dest, /*emit_object=*/false);
    } else if (fmt == "ll") {
      mptr_->print(dest, nullptr);
    } else if (fmt == "bc") {
#if TVM_LLVM_VERSION <= 60
      llvm::WriteBitcodeToFile(mptr_, dest);
#else
      llvm::WriteBitcodeToFile(*mptr_, dest);
#endif
    } else {
      LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format
                 << "\'";
    }
    dest.close();
  }

  /*! \brief Binary serialization is not supported; always fatal. */
  void SaveToBinary(dmlc::Stream* stream) final {
    LOG(FATAL) << "LLVMModule: SaveToBinary not supported";
  }

  /*!
   * \brief Render the module as text.
   *
   * \param format "s"/"asm" for target assembly, "" or "ll" for LLVM IR.
   *        Any other format is fatal.
   * \return The requested textual representation.
   */
  std::string GetSource(const std::string& format) final {
    std::string fmt = runtime::GetFileFormat("", format);
    if (fmt == "s" || fmt == "asm") {
      llvm::SmallString<256> str;
      llvm::raw_svector_ostream rso(str);
      EmitMachineCode(rso, /*emit_object=*/false);
      return rso.str().str();
    } else if (fmt == "" || fmt == "ll") {
      std::string type_str;
      llvm::raw_string_ostream rso(type_str);
      ICHECK(mptr_ != nullptr);
      mptr_->print(rso, nullptr);
      return rso.str();
    } else {
      LOG(FATAL) << "Do not know how to get source code with format: \'" << format << "\'";
    }
    return "";
  }

  /*!
   * \brief Build the LLVM module from an IRModule.
   *
   * Collects the PrimFuncs of \p mod (non-PrimFunc entries are skipped),
   * generates code for them, attaches the target and debug-info module flags
   * and verifies the result.
   *
   * \param mod The IR module to lower.
   * \param target The compilation target.
   */
  void Init(const IRModule& mod, const Target& target) {
    InitializeLLVM();
    tm_ = GetLLVMTargetMachine(target);
    ctx_ = std::make_shared<llvm::LLVMContext>();
    std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());

    std::vector<PrimFunc> funcs;
    std::string entry_func;
    Map<String, LinkedParam> linked_params;
    bool found_linked_params = false;
    bool could_have_linked_params = mod->ShouldLinkParameters();
    relay::Runtime runtime =
        mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp"));
    bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
    bool target_c_runtime = runtime->name == "crt";

    for (auto kv : mod->functions) {
      if (could_have_linked_params &&
          kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
        // The linked-param lookup function only carries the parameter table; it is not lowered.
        Map<String, ObjectRef> attrs_dict =
            Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
        CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end())
            << "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!";
        linked_params =
            Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
        found_linked_params = true;
        continue;
      }
      if (!kv.second->IsInstance<PrimFuncNode>()) {
        // (@jroesch): we relax constraints here, Relay functions will just be ignored.
        DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got "
                   << kv.second->GetTypeKey();
        continue;
      }
      auto f = Downcast<PrimFunc>(kv.second);
      auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
      ICHECK(global_symbol.defined());
      function_names_.push_back(global_symbol.value());
      if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
        entry_func = global_symbol.value();
      }
      funcs.push_back(f);
    }
    // TODO(@jroesch): follow up on this condition.
    // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
    // TODO(tqchen): remove the entry function behavior as it does not
    // make sense when we start to use multiple modules.
    cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);

    // See https://llvm.org/docs/LangRef.html#fast-math-flags for details
    Bool fast_math_all = target->GetAttr<Bool>("fast-math").value_or(Bool(false));
    Bool fast_math_nnan = target->GetAttr<Bool>("fast-math-nnan").value_or(Bool(false));
    Bool fast_math_ninf = target->GetAttr<Bool>("fast-math-ninf").value_or(Bool(false));
    Bool fast_math_nsz = target->GetAttr<Bool>("fast-math-nsz").value_or(Bool(false));
    Bool fast_math_arcp = target->GetAttr<Bool>("fast-math-arcp").value_or(Bool(false));

    llvm::FastMathFlags fmf;
    if (fast_math_all) {
#if TVM_LLVM_VERSION >= 60
      fmf.setFast();
#else
      fmf.setUnsafeAlgebra();
#endif
    }

    if (fast_math_nnan) {
      fmf.setNoNaNs();
    }
    if (fast_math_ninf) {
      fmf.setNoInfs();
    }
    if (fast_math_nsz) {
      fmf.setNoSignedZeros();
    }
    if (fast_math_arcp) {
      fmf.setAllowReciprocal();
    }

#if TVM_LLVM_VERSION >= 60
    Bool fast_math_contract = target->GetAttr<Bool>("fast-math-contract").value_or(Bool(false));
    Bool fast_math_afn = target->GetAttr<Bool>("fast-math-afn").value_or(Bool(false));
    Bool fast_math_reassoc = target->GetAttr<Bool>("fast-math-reassoc").value_or(Bool(false));
    if (fast_math_contract) {
      fmf.setAllowContract(true);
    }
    if (fast_math_afn) {
      fmf.setApproxFunc();
    }
    if (fast_math_reassoc) {
      fmf.setAllowReassoc();
    }
#endif

    cg->SetFastMathFlag(fmf);

    cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
    if (entry_func.length() != 0) {
      cg->AddMainFunction(entry_func);
    }

    if (found_linked_params) {
      cg->LinkParameters(linked_params);
    }
    module_ = cg->Finish();
    module_->addModuleFlag(llvm::Module::Warning, "tvm_target",
                           llvm::MDString::get(*ctx_, LLVMTargetToString(target)));
    module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
                           llvm::DEBUG_METADATA_VERSION);

    if (tm_->getTargetTriple().isOSDarwin()) {
      module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
    }

    std::string verify_errors_storage;
    llvm::raw_string_ostream verify_errors(verify_errors_storage);
    LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
        << "LLVM module verification failed with the following errors: \n"
        << verify_errors.str();
    target_ = target;
    mptr_ = module_.get();
  }

  /*!
   * \brief Adopt an already-built llvm::Module.
   *
   * The target machine is derived from the module's "tvm_target" flag when
   * present, otherwise from the module's target triple.
   *
   * \param module The LLVM module to take ownership of; must not be null.
   * \param ctx The LLVM context that \p module was created in.
   */
  void Init(std::unique_ptr<llvm::Module> module, std::shared_ptr<llvm::LLVMContext> ctx) {
    InitializeLLVM();
    ctx_ = ctx;
    module_ = std::move(module);
    if (module_ == nullptr) {
      // No parser diagnostic is available at this point, so state the cause directly.
      LOG(FATAL) << "Fail to load module: the provided llvm::Module is null";
    }
    std::string target_metadata;
    llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target");
    if (tvm_target != nullptr) {
      llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(tvm_target);
      ICHECK(pstr != nullptr);
      target_metadata = pstr->getString().str();
      // Prepend the "llvm" target kind when the stored string lacks it.
      if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) {
        target_metadata = "llvm " + target_metadata;
      }
    } else {
      std::ostringstream os;
      os << "llvm -mtriple " << module_->getTargetTriple();
      target_metadata = os.str();
    }
    mptr_ = module_.get();
    tm_ = GetLLVMTargetMachine(Target(target_metadata));
  }

  /*!
   * \brief Parse an LLVM IR file and initialize this module from it.
   * \param file_name Path of the IR file handed to llvm::parseIRFile.
   */
  void LoadIR(const std::string& file_name) {
    auto ctx = std::make_shared<llvm::LLVMContext>();
    llvm::SMDiagnostic err;
    auto module = llvm::parseIRFile(file_name, err, *ctx);
    if (module == nullptr) {
      std::string msg = std::string(err.getMessage());
      LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
    Init(std::move(module), ctx);
  }

 private:
  /*!
   * \brief Run the target code generator on a clone of the module.
   *
   * Shared by SaveToFile and GetSource so the LLVM-version-specific
   * addPassesToEmitFile call lives in one place.
   *
   * \param out Stream receiving the generated code.
   * \param emit_object True to emit an object file, false to emit assembly.
   */
  void EmitMachineCode(llvm::raw_pwrite_stream& out, bool emit_object) {
#if TVM_LLVM_VERSION <= 60
    std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
#else
    std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
    llvm::legacy::PassManager pass;
    ICHECK(tm_);
#if TVM_LLVM_VERSION <= 90
    const auto file_type = emit_object ? llvm::TargetMachine::CGFT_ObjectFile
                                       : llvm::TargetMachine::CGFT_AssemblyFile;
#else
    const auto file_type = emit_object ? llvm::CGFT_ObjectFile : llvm::CGFT_AssemblyFile;
#endif
    const char* file_type_name = emit_object ? "CGFT_ObjectFile" : "CGFT_AssemblyFile";
#if TVM_LLVM_VERSION <= 60
    ICHECK(tm_->addPassesToEmitFile(pass, out, file_type) == 0)
        << "Cannot emit target " << file_type_name;
#else
    ICHECK(tm_->addPassesToEmitFile(pass, out, nullptr, file_type) == 0)
        << "Cannot emit target " << file_type_name;
#endif
    pass.run(*m);
  }

  /*!
   * \brief Create the execution engine if it does not exist yet.
   *
   * Takes mutex_ and returns immediately when the engine is already present.
   * Ownership of module_ is moved into the engine; mptr_ keeps pointing at it.
   */
  void LazyInitJIT() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (ee_) {
      return;
    }
    if (!target_.defined()) {
      target_ = Target("llvm");
    }
    llvm::EngineBuilder builder(std::move(module_));
    std::string triple, mcpu, mattr;
    llvm::TargetOptions opt;
    ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt);
    builder.setEngineKind(llvm::EngineKind::JIT);
    builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
    if (mcpu.length() != 0) {
      builder.setMCPU(mcpu);
    }
    if (mattr.length() != 0) {
      std::vector<std::string> mattrs{mattr};
      builder.setMAttrs(mattrs);
    }
    builder.setTargetOptions(opt);
    auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
    // Refuse to JIT code whose architecture differs from the default "llvm" target.
    std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine(Target("llvm"));
    if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
      LOG(FATAL) << "Cannot run module, architecture mismatch "
                 << " module=" << tm->getTargetTriple().str()
                 << " system=" << tm_sys->getTargetTriple().str();
    }
    llvm::DataLayout layout(tm->createDataLayout());
    ICHECK(layout == mptr_->getDataLayout())
        << "Data layout mismatch between module("
        << mptr_->getDataLayout().getStringRepresentation() << ")"
        << " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
    // The engine takes ownership of the target machine.
    ee_ = builder.create(tm.release());
    ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << mptr_->getTargetTriple();
    ee_->runStaticConstructorsDestructors(false);

    // Store this module in the generated module-context global, when that global exists.
    if (void** ctx_addr =
            reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
      *ctx_addr = this;
    }
    runtime::InitContextFunctions(
        [this](const char* name) { return reinterpret_cast<void*>(GetGlobalAddr(name)); });
  }

  /*!
   * \brief Get the address of a global variable from the execution engine.
   * \param name The name of the global variable.
   * \return The address, or 0 when the module has no such global variable.
   */
  uint64_t GetGlobalAddr(const std::string& name) const {
    // first verifies if GV exists.
    if (mptr_->getGlobalVariable(name) != nullptr) {
      return ee_->getGlobalValueAddress(name);
    } else {
      return 0;
    }
  }

  /*!
   * \brief Get the address of a function from the execution engine.
   * \param name The name of the function.
   * \return The address, or 0 when the module has no such function.
   */
  uint64_t GetFunctionAddr(const std::string& name) const {
    // first verifies if the function exists.
    if (mptr_->getFunction(name) != nullptr) {
      return ee_->getFunctionAddress(name);
    } else {
      return 0;
    }
  }

  // The target configuration string
  Target target_;
  // JIT lock
  std::mutex mutex_;
  // execution engine
  llvm::ExecutionEngine* ee_{nullptr};
  // The raw pointer to the module.
  llvm::Module* mptr_{nullptr};
  // The target machine
  std::unique_ptr<llvm::TargetMachine> tm_{nullptr};
  // The module, can be moved to ee if JIT is enabled.
  std::unique_ptr<llvm::Module> module_;
  // the context.
  std::shared_ptr<llvm::LLVMContext> ctx_;
  /* \brief names of the functions declared in this module */
  Array<String> function_names_;
};

// Build an LLVM runtime module from an IRModule for the given target.
TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      auto n = make_object<LLVMModuleNode>();
      n->Init(mod, target);
      return runtime::Module(n);
    });

// Create an empty LLVM module carrying the target's triple and data layout.
TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
    .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
      Target target = Target(target_str);
      auto n = make_object<LLVMModuleNode>();
      // Generate a LLVM module from an input target string
      InitializeLLVM();
      auto tm = GetLLVMTargetMachine(target);
      auto ctx = std::make_shared<llvm::LLVMContext>();
      std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));

      // Use a default data layout and target triple
      auto triple = tm->getTargetTriple();
      module->setTargetTriple(triple.str());
      module->setDataLayout(tm->createDataLayout());
      n->Init(std::move(module), ctx);

      return runtime::Module(n);
    });

// Map an LLVM intrinsic name to its numeric intrinsic id.
TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
    .set_body_typed([](std::string name) -> int64_t {
      return static_cast<int64_t>(llvm::Function::lookupIntrinsicID(name));
    });

// Report TVM_LLVM_VERSION divided by 10.
TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
  return TVM_LLVM_VERSION / 10;
});

// Load a module from a textual LLVM IR file.
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
    .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
      auto n = make_object<LLVMModuleNode>();
      n->LoadIR(filename);
      return runtime::Module(n);
    });

// Check whether a target machine can be created for the given target string.
TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled")
    .set_body_typed([](std::string target_str) -> bool {
      InitializeLLVM();
      Target target = Target(target_str);
      return (GetLLVMTargetMachine(target, true) != nullptr);
    });

// Wrap the result of CodeGenBlob in an LLVM runtime module.
TVM_REGISTER_GLOBAL("codegen.codegen_blob")
    .set_body_typed([](std::string data, bool system_lib,
                       std::string target_triple) -> runtime::Module {
      auto n = make_object<LLVMModuleNode>();
      auto p = CodeGenBlob(data, system_lib, target_triple);
      n->Init(std::move(p.first), p.second);
      return runtime::Module(n);
    });

/*!
 * \brief Create the metadata module that defines the function registry.
 *
 * Gathers the names reported by each module's "get_func_names" function,
 * emits a function registry for them, and imports all \p modules into the
 * resulting LLVM module.
 *
 * \param modules The modules whose functions are registered.
 * \param target The compilation target.
 * \param runtime The runtime configuration; must be "crt" with system-lib enabled.
 * \return The metadata module with \p modules imported.
 */
runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& modules, Target target,
                                            tvm::relay::Runtime runtime) {
  Array<String> func_names;
  for (runtime::Module mod : modules) {
    auto pf_funcs = mod.GetFunction("get_func_names");
    if (pf_funcs != nullptr) {
      Array<String> mod_func_names = pf_funcs();
      for (const auto& fname : mod_func_names) {
        func_names.push_back(fname);
      }
    }
  }

  InitializeLLVM();
  auto tm = GetLLVMTargetMachine(target);
  bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
  bool target_c_runtime = runtime->name == "crt";
  ICHECK(system_lib && target_c_runtime)
      << "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; "
      << "got target: " << target->str();
  auto ctx = std::make_shared<llvm::LLVMContext>();
  std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
  cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime);

  cg->DefineFunctionRegistry(func_names);
  auto mod = cg->Finish();
  mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
                     llvm::MDString::get(*ctx, LLVMTargetToString(target)));
  mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION);

  if (tm->getTargetTriple().isOSDarwin()) {
    mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
  }

  std::string verify_errors_storage;
  llvm::raw_string_ostream verify_errors(verify_errors_storage);
  LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
      << "LLVM module verification failed with the following errors: \n"
      << verify_errors.str();

  auto n = make_object<LLVMModuleNode>();
  n->Init(std::move(mod), ctx);
  for (auto m : modules) {
    n->Import(m);
  }
  return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.CreateLLVMCrtMetadataModule")
    .set_body_typed(CreateLLVMCrtMetadataModule);

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION