/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #include #include #include #include #include #include #include #include #include #include "../../utils.h" #include "codegen_c.h" namespace tvm { namespace relay { namespace contrib { using namespace backend; /*! * \brief An example codegen that is only used for quick prototyping and testing * purpose. Only several binary options are covered. Users * may need to extend them to cover more operators. */ class CodegenC : public MemoizedExprTranslator>, public CodegenCBase { public: explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } std::vector VisitExprDefault_(const Object* op) final { LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey(); return {}; } std::vector VisitExpr_(const VarNode* node) final { ext_func_args_.push_back(GetRef(node)); Output output; output.name = node->name_hint(); return {output}; } std::vector VisitExpr_(const TupleNode* node) final { std::vector outs; for (auto field : node->fields) { auto res = VisitExpr(field); ICHECK_EQ(res.size(), 1U) << "Do not support tuple nest"; outs.push_back(res[0]); } return outs; } std::vector VisitExpr_(const TupleGetItemNode* op) final { auto res = VisitExpr(op->tuple); ICHECK_GT(res.size(), static_cast(op->index)); // Only keep the item we want for the child node. // FIXME(@comaniac): The other items should still be requried for the primary outputs. return {res[op->index]}; } std::vector VisitExpr_(const ConstantNode* cn) final { std::ostringstream decl_stream; std::ostringstream buf_stream; Output output; // Get const: static_cast(gcc_0_consts[0]->data) output.name = CreateDataReference(ext_func_id_, const_idx_); const auto* type_node = cn->checked_type().as(); ICHECK(type_node); const auto& dtype = GetDtypeString(type_node); // Generate the global variable for needed ndarrays if (const_array_name_.empty()) { const_array_name_ = CreateNDArrayPool(ext_func_id_); std::string checker = CreateInitChecker(ext_func_id_); ext_func_body_.insert(ext_func_body_.begin(), checker); } ICHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; output.dtype = dtype; std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); const_vars_.push_back(const_var_name); const_idx_++; return {output}; } std::vector VisitExpr_(const CallNode* call) final { std::ostringstream macro_stream; std::ostringstream decl_stream; std::ostringstream buf_stream; std::string func_name = ext_func_id_ + "_" + std::to_string(func_idx++); // Make function declaration macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", "; if (IsOp(call, "add")) { macro_stream << "+"; } else if (IsOp(call, "subtract")) { macro_stream << "-"; } else if (IsOp(call, "multiply")) { macro_stream << "*"; } else { LOG(FATAL) << "Unrecognized op"; } auto in_shape = GetShape(call->args[0]->checked_type()); for (size_t i = 0; i < in_shape.size(); ++i) { macro_stream << ", " << in_shape[i]; } const auto* type_node = call->checked_type().as(); ICHECK(type_node); const auto& dtype = GetDtypeString(type_node); macro_stream << ", " << dtype; macro_stream << ");"; func_decl_.push_back(macro_stream.str()); // Make function call when visiting arguments bool first = true; decl_stream << func_name << "("; for (size_t i = 0; i < call->args.size(); ++i) { auto res = VisitExpr(call->args[i]); for (auto out : res) { if (!first) { decl_stream << ", "; } first = false; decl_stream << out.name; } } std::string out = "buf_" + std::to_string(buf_idx_++); auto out_shape = GetShape(call->checked_type()); int out_size = 1; for (size_t i = 0; i < out_shape.size(); ++i) { out_size *= out_shape[i]; } buf_stream << dtype << "* " << out << " = (" << dtype << "*)malloc(4 * " << out_size << ");"; buf_decl_.push_back(buf_stream.str()); decl_stream << ", " << out << ");"; ext_func_body_.push_back(decl_stream.str()); // Update output buffer // Note C codegen only handles TensorType. Therefore, we don't flatten // tuples and only return a single vaule. Output output; output.name = out; output.dtype = dtype; output.need_copy = true; output.size = out_size; return {output}; } /*! * \brief Emit the source code that invokes C compiler compatible wrappers. * * \return The emitted code. */ std::string JIT(const std::vector& out) { // Write function macros for (auto decl : func_decl_) { code_stream_ << decl << "\n"; } return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); } private: /*! \brief The function id that represents a C source function. */ std::string ext_func_id_ = ""; /*! \brief The index of a wrapped C function. */ int func_idx = 0; /*! \brief The index of allocated buffers. */ int buf_idx_ = 0; /*! \brief The index of global constants. */ int const_idx_ = 0; /*! \brief The arguments of a C compiler compatible function. */ Array ext_func_args_; /*! \brief The statements of a C compiler compatible function. */ std::vector ext_func_body_; /*! \brief The array declared to store the constant values. */ std::string const_array_name_; /*! \brief The declaration statements of a C compiler compatible function. */ std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; /*! \brief The variable name to constant mapping. */ Array const_vars_; friend class CSourceCodegen; }; class CSourceCodegen : public CSourceModuleCodegenBase { public: std::tuple, String, String> GenCFunc(const Function& func) { ICHECK(func.defined()) << "Input error: expect a Relay function."; CodegenC builder(GetExtSymbol(func)); auto out = builder.VisitExpr(func->body); return std::make_tuple(builder.const_vars_, builder.ext_func_id_, builder.JIT(out)); } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { ICHECK(ref->IsInstance()); auto res = GenCFunc(Downcast(ref)); Array variables = std::get<0>(res); String func_name = std::get<1>(res); // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; if (!variables.empty()) { // This segment would be generated in C++ because of the usage // of tvm::runtime::Array. This is not ideal, but this to demonstrate // constant copying process used packed imports in other external // codegen. Moreover, in microTVM we dont expect this part to be generated. code_stream_ << "#ifdef __cplusplus\n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#endif\n"; } // Append some common macro for operator definition. const char* operator_macro = R"op_macro( #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE) \ void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \ for (int64_t i = 0; i < p_DIM1_; ++i) { \ out[i] = a[i] p_OP_ b[i]; \ } \ } #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_, p_DTYPE) \ void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \ for (int64_t i = 0; i < p_DIM1_; ++i) { \ for (int64_t j = 0; j < p_DIM2_; ++j) { \ int64_t k = i * p_DIM2_ + j; \ out[k] = a[k] p_OP_ b[k]; \ } \ } \ } )op_macro"; code_stream_ << operator_macro << "\n\n"; code_stream_ << std::get<2>(res); std::string code = code_stream_.str(); // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; return (*pf)(code, "c", Array{func_name}, variables); } private: std::ostringstream code_stream_; }; /*! * \brief The external compiler/codegen tool. It takes a Relay expression/module and * compile it into a runtime module. * * The external codegen tool should have been registered similiarly to LLVM, * CUDA, etc, under TVM, so the generated code could be packed in a runtime * module. This module simplifies code serialization and invocation. */ runtime::Module CCompiler(const ObjectRef& ref) { CSourceCodegen csource; return csource.CreateCSourceModule(ref); } TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); } // namespace contrib } // namespace relay } // namespace tvm