/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file module_util.cc
 * \brief Utilities for module.
 */
#include "library_module.h"

#include <dmlc/memory_io.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>

#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {

/*!
 * \brief Module node that exposes the symbols of a Library as PackedFuncs.
 *
 * Every lookup resolves a symbol in the wrapped library and converts the raw
 * C function pointer into a PackedFunc through the supplied wrapper.
 */
class LibraryModuleNode final : public ModuleNode {
 public:
  /*!
   * \param lib The library whose symbols are exposed.
   * \param wrapper Converts a raw TVMBackendPackedCFunc into a PackedFunc.
   */
  explicit LibraryModuleNode(ObjectPtr<Library> lib, PackedFuncWrapper wrapper)
      : lib_(std::move(lib)), packed_func_wrapper_(std::move(wrapper)) {}

  const char* type_key() const final { return "library"; }

  /*!
   * \brief Look up a function exported by the library.
   * \param name Symbol name. The special name runtime::symbol::tvm_module_main is
   *        resolved indirectly: that symbol stores the name of the real entry function.
   * \param sptr_to_self Owning pointer to this module; forwarded to the wrapper.
   * \return The wrapped function, or an empty PackedFunc if the symbol is absent.
   */
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    TVMBackendPackedCFunc faddr = nullptr;
    if (name == runtime::symbol::tvm_module_main) {
      // The main-entry symbol holds a C string naming the actual entry function.
      const char* entry_name =
          reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
      ICHECK(entry_name != nullptr)
          << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
      faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(entry_name));
    } else {
      faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(name.c_str()));
    }
    if (faddr == nullptr) return PackedFunc();
    return packed_func_wrapper_(faddr, sptr_to_self);
  }

 private:
  /*! \brief The library that provides the symbols. */
  ObjectPtr<Library> lib_;
  /*! \brief Converts raw C function pointers into PackedFuncs. */
  PackedFuncWrapper packed_func_wrapper_;
};

/*!
 * \brief Helper class to access the internals of a module.
*/ class ModuleInternal { public: // Get mutable reference of imports. static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { TVMValue ret_value; int ret_type_code = kTVMNullptr; int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), args.num_args, &ret_value, &ret_type_code, nullptr); ICHECK_EQ(ret, 0) << TVMGetLastError(); if (ret_type_code != kTVMNullptr) { *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); } }); } void InitContextFunctions(std::function fgetsymbol) { #define TVM_INIT_CONTEXT_FUNC(FuncName) \ if (auto* fp = reinterpret_cast(fgetsymbol("__" #FuncName))) { \ *fp = FuncName; \ } // Initialize the functions TVM_INIT_CONTEXT_FUNC(TVMFuncCall); TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); #undef TVM_INIT_CONTEXT_FUNC } Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { std::string loadkey = "runtime.module.loadbinary_"; std::string fkey = loadkey + type_key; const PackedFunc* f = Registry::Get(fkey); if (f == nullptr) { std::string loaders = ""; for (auto name : Registry::ListNames()) { if (name.find(loadkey, 0) == 0) { if (loaders.size() > 0) { loaders += ", "; } loaders += name.substr(loadkey.size()); } } LOG(FATAL) << "Binary was created using " << type_key << " but a loader of that name is not registered. Available loaders are " << loaders << ". Perhaps you need to recompile with this runtime enabled."; } return (*f)(static_cast(stream)); } /*! * \brief Load and append module blob to module list * \param mblob The module blob. * \param lib The library. 
* \param root_module the output root module * \param dso_ctx_addr the output dso module */ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; for (size_t i = 0; i < sizeof(nbytes); ++i) { uint64_t c = mblob[i]; nbytes |= (c & 0xffUL) << (i * 8); } dmlc::MemoryFixedSizeStream fs(const_cast(mblob + sizeof(nbytes)), static_cast(nbytes)); dmlc::Stream* stream = &fs; uint64_t size; ICHECK(stream->Read(&size)); std::vector modules; std::vector import_tree_row_ptr; std::vector import_tree_child_indices; int num_dso_module = 0; for (uint64_t i = 0; i < size; ++i) { std::string tkey; ICHECK(stream->Read(&tkey)); // "_lib" serves as a placeholder in the module import tree to indicate where // to place the DSOModule if (tkey == "_lib") { auto dso_module = Module(make_object(lib, packed_func_wrapper)); *dso_ctx_addr = dso_module.operator->(); ++num_dso_module; modules.emplace_back(dso_module); ICHECK_EQ(num_dso_module, 1U) << "Multiple dso module detected, please upgrade tvm " << " to the latest before exporting the module"; } else if (tkey == "_import_tree") { ICHECK(stream->Read(&import_tree_row_ptr)); ICHECK(stream->Read(&import_tree_child_indices)); } else { auto m = LoadModuleFromBinary(tkey, stream); modules.emplace_back(m); } } // if we are using old dll, we don't have import tree // so that we can't reconstruct module relationship using import tree if (import_tree_row_ptr.empty()) { auto n = make_object(lib, packed_func_wrapper); auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); for (const auto& m : modules) { module_import_addr->emplace_back(m); } *dso_ctx_addr = n.get(); *root_module = Module(n); } else { for (size_t i = 0; i < modules.size(); ++i) { for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { auto module_import_addr = 
ModuleInternal::GetImportsAddr(modules[i].operator->()); auto child_index = import_tree_child_indices[j]; ICHECK(child_index < modules.size()); module_import_addr->emplace_back(modules[child_index]); } } ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; // invariance: root module is always at location 0. // The module order is collected via DFS *root_module = modules[0]; } } Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); auto n = make_object(lib, packed_func_wrapper); // Load the imported modules const char* dev_mblob = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); Module root_mod; runtime::ModuleNode* dso_ctx_addr = nullptr; if (dev_mblob != nullptr) { ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); } else { // Only have one single DSO Module root_mod = Module(n); dso_ctx_addr = root_mod.operator->(); } // allow lookup of symbol from root (so all symbols are visible). if (auto* ctx_addr = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { *ctx_addr = dso_ctx_addr; } return root_mod; } TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectPtr n = CreateDSOLibraryObject(args[0]); *rv = CreateModuleFromLibrary(n); }); } // namespace runtime } // namespace tvm