/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file interface_c.cc * \brief Generates a C interface header for a given modules inputs and outputs */ #include #include #include #include #include #include #include "../../relay/backend/name_transforms.h" namespace tvm { namespace codegen { using runtime::PackedFunc; using namespace tvm::relay::backend; class InterfaceCNode : public runtime::ModuleNode { public: InterfaceCNode(std::string module_name, Array inputs, Array outputs, Array devices, int workspace_size) : module_name_(module_name), inputs_(inputs), outputs_(outputs), devices_(devices), workspace_size_(workspace_size) {} const char* type_key() const { return "h"; } std::string GetSource(const std::string& format) final { std::stringstream code; EmitUpperHeaderGuard(code); EmitBrief(code, "Input tensor pointers"); EmitStruct(code, "inputs", inputs_); EmitBrief(code, "Output tensor pointers"); EmitStruct(code, "outputs", outputs_); if (!devices_.empty()) { EmitBrief(code, "Device context pointers"); EmitStruct(code, "devices", devices_); } EmitRunFunction(code); EmitWorkspaceSize(code); EmitLowerHeaderGuard(code); return code.str(); } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { return PackedFunc(nullptr); } private: void EmitUpperHeaderGuard(std::stringstream& code_stream) { std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "H"})); code_stream << "#ifndef " << header_guard_name << "_\n" << "#define " << header_guard_name << "_\n" << "#include \n\n" << "#ifdef __cplusplus\n" << "extern \"C\" {\n" << "#endif\n\n"; } void EmitLowerHeaderGuard(std::stringstream& code_stream) { std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "H"})); code_stream << "\n#ifdef __cplusplus\n" << "}\n" << "#endif\n\n" << "#endif // " << header_guard_name << "_\n"; } void EmitBrief(std::stringstream& code_stream, const std::string& description) { code_stream << "/*!\n" << " * \\brief " << description << " for TVM module \"" << module_name_ << "\" \n" << " */\n"; } void EmitStruct(std::stringstream& code_stream, const std::string& suffix, Array properties) { std::string struct_name = ToCVariableStyle(PrefixGeneratedName({module_name_, suffix})); code_stream << "struct " << struct_name << " {\n"; std::vector sanitized_properties; for (const String& property : properties) { std::string sanitized_property = SanitizeName(property); ICHECK(std::find(sanitized_properties.begin(), sanitized_properties.end(), sanitized_property) == sanitized_properties.end()) << "Sanitized input tensor name clash" << sanitized_property; code_stream << " void* " << sanitized_property << ";\n"; sanitized_properties.push_back(sanitized_property); } code_stream << "};\n\n"; } void EmitRunFunction(std::stringstream& code_stream) { std::string run_function = ToCVariableStyle(PrefixGeneratedName({module_name_, "run"})); std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"})); std::string outputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "outputs"})); std::string devices_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "devices"})); code_stream << "/*!\n" << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n" << " * \\param inputs Input tensors for the module \n" << " * \\param outputs Output tensors for the module \n"; if (!devices_.empty()) { code_stream << " * \\param devices Device context pointers for the module \n"; } code_stream << " */\n" << "int32_t " << run_function << "(\n" << " struct " << inputs_struct << "* inputs,\n"; if (!devices_.empty()) { code_stream << " struct " << outputs_struct << "* outputs,\n"; code_stream << " struct " << devices_struct << "* devices\n"; } else { code_stream << " struct " << outputs_struct << "* outputs\n"; } code_stream << ");\n"; } void EmitWorkspaceSize(std::stringstream& code_stream) { std::string workspace_size_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "WORKSPACE_SIZE"})); code_stream << "/*!\n" << " * \\brief Workspace size for TVM module \"" << module_name_ << "\"\n" << " */\n" << "#define " << workspace_size_name << " " << workspace_size_ << "\n"; } std::string module_name_; Array inputs_; Array outputs_; Array devices_; int workspace_size_; }; runtime::Module InterfaceCCreate(std::string module_name, Array inputs, Array outputs, Array devices, int workspace_size) { auto n = make_object(module_name, inputs, outputs, devices, workspace_size); return runtime::Module(n); } TVM_REGISTER_GLOBAL("runtime.InterfaceCCreate").set_body_typed(InterfaceCCreate); } // namespace codegen } // namespace tvm