/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file src/relay/backend/contrib/cutlass/codegen.cc * \brief Implementation of CUTLASS codegen. */ #include #include #include #include #include #include #include #include #include #include "../../utils.h" #include "../codegen_c/codegen_c.h" namespace tvm { namespace relay { namespace contrib { using namespace backend; using Str2StrMap = std::unordered_map; static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, {"float32", "float"}}; constexpr const char* kAnyDim = "Any"; std::string GetDimAsStr(ObjectRef dim) { if (auto d = dim.as()) { return std::to_string(d->value); } return kAnyDim; } inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { for (int i = 0; i < indent; ++i) { os << " "; } os << stmt; } Str2StrMap ArgsCommon(const Map& attrs) { Str2StrMap args; auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); auto ret_dtype = std::string(attrs["ret_dtype"].as()->data); args["ElementInputA"] = dtype_map.at(arg0_dtype); args["ElementInputB"] = dtype_map.at(arg1_dtype); args["ElementOutput"] = dtype_map.at(ret_dtype); args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); args["op_name"] = std::string(attrs["cutlass_op_name"].as()->data); args["op_type"] = std::string(attrs["op_type"].as()->data); return args; } Str2StrMap GemmArgsCommon(const Map& attrs) { Str2StrMap args = ArgsCommon(attrs); args["lda"] = std::string(attrs["lda"].as()->data); args["ldb"] = std::string(attrs["ldb"].as()->data); args["ldc"] = std::string(attrs["ldc"].as()->data); return args; } Str2StrMap DenseArgs(const Map& attrs) { Str2StrMap args = GemmArgsCommon(attrs); auto arg0_shape = attrs["arg0_shape"].as(); auto arg1_shape = attrs["arg1_shape"].as(); args["M"] = GetDimAsStr(arg0_shape->at(0)); args["K"] = GetDimAsStr(arg0_shape->at(1)); args["N"] = GetDimAsStr(arg1_shape->at(0)); return args; } Str2StrMap BatchMatmulArgs(const Map& attrs) { Str2StrMap args = GemmArgsCommon(attrs); args["batch"] = GetDimAsStr(attrs["batch"]); args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]); args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]); args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]); auto arg0_shape = attrs["arg0_shape"].as(); auto arg1_shape = attrs["arg1_shape"].as(); args["M"] = GetDimAsStr(arg0_shape->at(1)); args["K"] = GetDimAsStr(arg0_shape->at(2)); args["N"] = GetDimAsStr(arg1_shape->at(1)); return args; } void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs, const std::vector& func_args, const std::string& kernel, bool has_bias, bool is_gelu, int m_axis_idx, int n_axis_idx, int k_axis_idx) { CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n"); CutlassPrint(gemm_decl, attrs.at("op_def")); CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n"); auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) { if (attrs.at(axis) == kAnyDim) { return func_args[arg_idx] + "->shape[" + std::to_string(axis_idx) + "]"; } else { return attrs.at(axis); } }; CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, m_axis_idx) + ";\n"); CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n"); CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n"); CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n"); CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); if (is_gelu) { // GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); } else { CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); } ICHECK(func_args.size() >= 2); CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); if (has_bias) { ICHECK(func_args.size() >= 3); CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); } CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0->data);\n"); CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n"); CutlassPrint(gemm_decl, " problem_size,\n"); } void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel) { // Using the arguments, query for extra workspace required for matrix multiplication computation CutlassPrint(gemm_decl, "size_t workspace_size = " + kernel + "::get_workspace_size(arguments);\n"); // Allocate workspace memory CutlassPrint(gemm_decl, "cutlass::device_memory::allocation workspace(workspace_size);\n"); // Instantiate CUTLASS kernel depending on template CutlassPrint(gemm_decl, kernel + " gemm_op;\n"); // Check the problem size is supported or not CutlassPrint(gemm_decl, "cutlass::Status status = gemm_op.can_implement(arguments);\n"); CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); // Initialize CUTLASS kernel with arguments and workspace pointer CutlassPrint(gemm_decl, "status = gemm_op.initialize(arguments, workspace.get());\n"); CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); // Launch initialized CUTLASS kernel CutlassPrint(gemm_decl, "status = gemm_op();\n"); CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); } std::string DenseOp(std::string id, const Str2StrMap& attrs, const std::vector& func_args) { bool has_bias = attrs.at("op_type").find("bias") != std::string::npos; bool is_gelu = attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 std::ostringstream gemm_decl; AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1); CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); if (has_bias) { CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); } else { CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); } CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); if (has_bias && !is_gelu) { CutlassPrint(gemm_decl, " {alpha},\n"); } else { // For GeLU, we explicitly specify the scale. CutlassPrint(gemm_decl, " {alpha, beta},\n"); } CutlassPrint(gemm_decl, " 1};\n"); // split_k_slices AppendGemmExecute(gemm_decl, "Gemm"); return gemm_decl.str(); } std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs, const std::vector& func_args) { std::ostringstream gemm_decl; AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1, 1, 2); auto get_batch_stride = [&attrs, &func_args](const std::string& name, int arg0_idx, int arg1_idx, int arg0_axis_idx, int arg1_axis_idx) { if (attrs.at(name) == kAnyDim) { return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx) + "] * " + func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx) + "]"; } else { return attrs.at(name); } }; CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) + ",\n"); CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) + ",\n"); CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); CutlassPrint(gemm_decl, " {alpha, beta},\n"); if (attrs.at("batch") == kAnyDim) { CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n"); } else { CutlassPrint(gemm_decl, attrs.at("batch") + "};\n"); } AppendGemmExecute(gemm_decl, "BatchedGemm"); return gemm_decl.str(); } Str2StrMap Conv2dArgs(const Map& attrs) { Str2StrMap args = ArgsCommon(attrs); auto arg0_shape = attrs["arg0_shape"].as(); auto arg1_shape = attrs["arg1_shape"].as(); auto out_shape = attrs["ret_shape"].as(); args["N"] = GetDimAsStr(arg0_shape->at(0)); args["H"] = GetDimAsStr(arg0_shape->at(1)); args["W"] = GetDimAsStr(arg0_shape->at(2)); args["C"] = GetDimAsStr(arg0_shape->at(3)); args["K"] = GetDimAsStr(arg1_shape->at(0)); args["R"] = GetDimAsStr(arg1_shape->at(1)); args["S"] = GetDimAsStr(arg1_shape->at(1)); args["P"] = GetDimAsStr(out_shape->at(1)); args["Q"] = GetDimAsStr(out_shape->at(2)); args["pad_h"] = GetDimAsStr(attrs["padding"].as()->at(0)); args["pad_w"] = GetDimAsStr(attrs["padding"].as()->at(1)); args["stride_h"] = GetDimAsStr(attrs["strides"].as()->at(0)); args["stride_w"] = GetDimAsStr(attrs["strides"].as()->at(1)); args["dilation_h"] = GetDimAsStr(attrs["dilation"].as()->at(0)); args["dilation_w"] = GetDimAsStr(attrs["dilation"].as()->at(1)); return args; } std::string Conv2dOp(std::string id, const Str2StrMap& attrs, const std::vector& func_args, bool has_residual_block = false) { bool has_bias = attrs.at("op_type").find("bias") != std::string::npos; bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" && attrs.at("op_type") != "cutlass.conv2d_bias_silu" && attrs.at("op_type") != "cutlass.conv2d_bias_hardswish"; std::ostringstream conv2d_decl; CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n"); CutlassPrint(conv2d_decl, attrs.at("op_def")); CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") + " = cutlass::conv::device::ImplicitGemmConvolution<" + attrs.at("op_name") + ">;\n"); CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n"); auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) { if (attrs.at(axis) == kAnyDim) { return var_name + "->shape[" + std::to_string(axis_idx) + "]"; } else { return attrs.at(axis); } }; CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n"); CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n"); CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n"); CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n"); CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n"); CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n"); CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n"); CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n"); CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n"); CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n"); CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n"); CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n"); CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n"); CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n"); CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n"); CutlassPrint( conv2d_decl, "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, " "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n"); ICHECK(func_args.size() >= 2); CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); if (has_residual_block) { ICHECK(func_args.size() >= 4); CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n"); CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n"); } else if (has_bias) { ICHECK(func_args.size() >= 3); CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); } CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n"); CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); if (has_bias && no_bias_scaling && !has_residual_block) { CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); } else { CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); } CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n"); CutlassPrint(conv2d_decl, "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n"); CutlassPrint(conv2d_decl, "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n"); CutlassPrint(conv2d_decl, "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n"); CutlassPrint(conv2d_decl, "TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n"); CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n"); CutlassPrint(conv2d_decl, " problem_size,\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_a), layout_A},\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_b), layout_B},\n"); if (has_residual_block) { CutlassPrint(conv2d_decl, " {static_cast(ptr_residual), layout_C},\n"); } else if (has_bias) { CutlassPrint( conv2d_decl, " {static_cast(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n"); } else { CutlassPrint(conv2d_decl, " {static_cast(ptr_out), layout_C},\n"); } CutlassPrint(conv2d_decl, " {static_cast(ptr_out),layout_D},\n"); if (has_residual_block) { CutlassPrint(conv2d_decl, "{alpha, beta},\n"); CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices CutlassPrint(conv2d_decl, "static_cast(ptr_bias),\n"); CutlassPrint(conv2d_decl, "nullptr, 0, K};\n"); } else if (has_bias && no_bias_scaling) { CutlassPrint(conv2d_decl, " {alpha}\n};\n"); } else { CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n"); } CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n"); CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n"); // Allocate workspace memory CutlassPrint(conv2d_decl, "cutlass::device_memory::allocation workspace(workspace_size);\n"); // Check the problem size is supported or not CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n"); CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); // Initialize CUTLASS kernel with arguments and workspace pointer CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n"); CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); // Launch initialized CUTLASS kernel CutlassPrint(conv2d_decl, "status = conv2d_op();\n"); CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); return conv2d_decl.str(); } class CodegenCutlass : public MemoizedExprTranslator>, public CodegenCBase { public: CodegenCutlass(const std::string& id, const Map& attrs) { this->ext_func_id_ = id; this->attrs_ = attrs; } std::vector VisitExprDefault_(const Object* op) final { LOG(FATAL) << "Cutlass codegen doesn't support: " << op->GetTypeKey(); return {}; } std::vector VisitExpr_(const VarNode* node) final { ext_func_args_.push_back(GetRef(node)); Output output; output.name = node->name_hint(); return {output}; } std::vector VisitExpr_(const CallNode* call) final { const auto* func = call->op.as(); ICHECK(func) << "Only composite function is supported for CUTLASS."; GenerateBodyOutput ret = GenerateCompositeFunctionCall(func, call); ext_func_body_.push_back(ret.decl); return ret.outputs; } std::string JIT(const std::vector& out) { code_stream_ << "void " << ext_func_id_ << "_("; for (const auto& arg : ext_func_args_) { code_stream_ << "DLTensor* " << arg->name_hint() << ", "; } for (size_t i = 0; i < out.size() - 1; ++i) { code_stream_ << "DLTensor* out" << i << ", "; } code_stream_ << "DLTensor* out" << out.size() - 1 << ") {\n"; this->EnterScope(); // Function body for (auto decl : buf_decl_) { this->PrintIndents(); code_stream_ << decl << "\n"; } code_stream_ << "\n"; for (auto stmt : ext_func_body_) { this->PrintIndents(); code_stream_ << stmt << "\n"; } this->ExitScope(); code_stream_ << "}\n"; this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, const_array_name_, out, true); return code_stream_.str(); } private: std::vector GetArgumentNames(const CallNode* call) { std::vector arg_names; for (size_t i = 0; i < call->args.size(); ++i) { auto res = VisitExpr(call->args[i]); for (const auto& out : res) { arg_names.push_back(out.name); } } return arg_names; } bool IsConv2dResidualBlock(const std::string& func_name) { return func_name.find("conv2d") != std::string::npos && func_name.find("residual") != std::string::npos; } // Is node `x` an ancestor of `y`? bool IsAncestor(const CallNode* x, const CallNode* y) { if (x == y) return true; for (auto arg : y->args) { const CallNode* arg_ptr = arg.as(); if (arg_ptr && IsAncestor(x, arg_ptr)) return true; } return false; } GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee, const CallNode* caller) { const auto pattern_name = callee->GetAttr(attr::kComposite); ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; if (pattern_name == "cutlass.dense") { const auto* dense_call = GetRootCall(callee->body.as(), 0, {"nn.dense"}); return GenerateBody(dense_call, "cutlass_dense", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.dense_bias") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->op.as()->name; const auto* dense_call = GetRootCall(callee->body.as(), 1, {"nn.dense", add_or_bias_add}); return GenerateBody(dense_call, "cutlass_dense_bias", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.dense_bias_relu") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; const auto* dense_call = GetRootCall(callee->body.as(), 2, {"nn.dense", add_or_bias_add, "nn.relu"}); return GenerateBody(dense_call, "cutlass_dense_bias_relu", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.dense_bias_gelu_fp16") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[1].as()->op.as()->name; const auto* dense_call = GetRootCall(callee->body.as(), 8, {"nn.dense", add_or_bias_add, "multiply", "cast", "erf", "cast", "multiply", "add", "multiply"}); return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.dense_bias_gelu_fp32") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[1].as()->op.as()->name; const auto* dense_call = GetRootCall( callee->body.as(), 6, {"nn.dense", add_or_bias_add, "multiply", "erf", "multiply", "add", "multiply"}); return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.batch_matmul") { const auto* batch_matmul_call = GetRootCall(callee->body.as(), 0, {"nn.batch_matmul"}); return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller), BatchMatmulArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d") { const auto* conv2d_call = GetRootCall(callee->body.as(), 0, {"nn.conv2d"}); return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->op.as()->name; const auto* conv2d_call = GetRootCall(callee->body.as(), 1, {"nn.conv2d", add_or_bias_add}); return GenerateBody(conv2d_call, "cutlass_conv2d_bias", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias_relu") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; const auto* conv2d_call = GetRootCall(callee->body.as(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"}); return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias_sigmoid") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; const auto* conv2d_call = GetRootCall(callee->body.as(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"}); return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias_silu") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; const auto* conv2d_call = GetRootCall(callee->body.as(), 2, {"nn.conv2d", add_or_bias_add, "multiply"}); return GenerateBody(conv2d_call, "cutlass_conv2d_bias_silu", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias_hardswish") { const CallNode* current_call = callee->body.as(); std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; const auto* conv2d_call = GetRootCall(callee->body.as(), 2, {"nn.conv2d", add_or_bias_add, "multiply"}); return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (IsConv2dResidualBlock(pattern_name.value())) { const CallNode* current_call = callee->body.as(); bool has_relu = current_call->args.size() == 1; const CallNode* binop = has_relu ? current_call->args[0].as() : current_call; ICHECK(binop->args.size() == 2); // Figure out which of the first or second argument corresponds to the residual input // The root conv2d call can be reached via the other input of the binary op int residual_index; if (binop->args[1].as()) { residual_index = 1; } else if (binop->args[0].as()) { residual_index = 0; } else { const CallNode* lhs = binop->args[0].as(); const CallNode* rhs = binop->args[1].as(); ICHECK(lhs && rhs); // The residual input should be an ancestor of the non-residual input residual_index = IsAncestor(rhs, lhs) ? 1 : 0; } const auto* non_residual_input = binop->args[!residual_index].as(); const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d"); ICHECK(conv2d_call); return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } LOG(FATAL) << "Unknown composite function: " << pattern_name; return {}; } GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name, const std::vector& func_args, const Str2StrMap& attribute_args) { // Make function call with input buffers when visiting arguements ICHECK_GT(func_args.size(), 0); std::ostringstream decl_stream; decl_stream << "(" << func_args[0]; for (size_t i = 1; i < func_args.size(); ++i) { decl_stream << ", " << func_args[i]; } // Analyze the output buffers std::vector out_types; if (root_call->checked_type()->IsInstance()) { auto type_node = root_call->checked_type().as(); for (auto field : type_node->fields) { ICHECK(field->IsInstance()); out_types.push_back(field); } } else if (root_call->checked_type()->IsInstance()) { ICHECK(root_call->checked_type()->IsInstance()); out_types.push_back(root_call->checked_type()); } else { LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false); } GenerateBodyOutput ret; for (const auto& out_type : out_types) { const std::string out = "out" + std::to_string(buf_idx_++); decl_stream << ", " << out; Output output; output.name = out; output.dtype = GetDtypeString(out_type.as()); output.need_copy = false; ret.outputs.push_back(output); } decl_stream << ");"; if (func_name.find("dense") != std::string::npos) { ret.decl = DenseOp(ext_func_id_, attribute_args, func_args); } else if (func_name == "cutlass_batch_matmul") { ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args); } else if (IsConv2dResidualBlock(func_name)) { ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args, true); } else if (func_name.find("conv2d") != std::string::npos) { ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args); } return ret; } /*! \brief The id of the external cutlass ext_func. */ std::string ext_func_id_{""}; /*! \brief The attrs of the external cutlass ext_func. */ Map attrs_; /*! * \brief The index to track the output buffer. Each kernel will redirect the * output to a buffer that may be consumed by other kernels. */ int buf_idx_{0}; /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ Array ext_func_args_; /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ std::vector ext_func_body_; /*! \brief The array declared to store the constant values. */ std::string const_array_name_; /*! \brief The declaration of intermediate buffers. */ std::vector buf_decl_; }; // class CodegenCutlass class CutlassModuleCodegen : public CSourceModuleCodegenBase { public: std::pair> GenCutlassFunc(const Function& func) { ICHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. auto sid = GetExtSymbol(func); const auto* attrs = func->attrs.as(); ICHECK(attrs != nullptr); const auto dict = attrs->dict; CodegenCutlass builder(sid, dict); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); return {sid, {}}; } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // create header code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; // cutlass header code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; ICHECK(ref->IsInstance()); auto res = GenCutlassFunc(Downcast(ref)); std::string code = code_stream_.str(); String sym = std::get<0>(res); Array variables = std::get<1>(res); // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; return (*pf)(code, "cu", Array{sym}, variables); } private: /*! \brief The code stream that will be compiled by NVCC */ std::ostringstream code_stream_; }; // CutlassModuleCodegen /*! * \brief The external cutlass compiler/codegen tool. It takes a Relay * expression/module and compile it into a runtime module. */ runtime::Module CutlassCompiler(const ObjectRef& ref) { CutlassModuleCodegen cutlass; return cutlass.CreateCSourceModule(ref); } TVM_REGISTER_GLOBAL("relay.ext.cutlass").set_body_typed(CutlassCompiler); } // namespace contrib } // namespace relay } // namespace tvm