/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file src/relay/backend/contrib/ethosn/codegen.cc * \brief The Relay -> Arm(R) Ethos(TM)-N command stream compiler. */ #include #include #include "codegen_ethosn.h" #include "ethosn_api.h" namespace tvm { namespace relay { namespace contrib { namespace ethosn { sl::TensorInfo GetTensorInfo(std::map> tensor_table, const Call& call) { if (tensor_table.find(call) != tensor_table.end()) return tensor_table[call][0]; return sl::TensorInfo(); } bool IsEthosnOp(const Call& call, const std::string& op_name) { if (call->op->IsInstance()) { Op op = Downcast(call->op); ICHECK(op.defined()); return op == Op::Get(op_name); } else { return false; } } bool IsEthosnFunc(const Call& call, const std::string& op_name) { if (call->op->IsInstance()) { Function func = Downcast(call->op); ICHECK(func.defined()); auto name_node = func->GetAttr(attr::kComposite); return name_node.value() == op_name; } return false; } std::map> InferTensorsVisitor::Infer(const Expr& expr) { tensor_table_.clear(); ICHECK(expr->checked_type().defined()); size_t output_size = 1; if (auto tuple = expr->checked_type().as()) { output_size = tuple->fields.size(); } for (size_t i = 0; i < output_size; i++) { tensor_table_[expr].push_back(sl::TensorInfo({1, 1, 1, 1}, sl::DataType::UINT8_QUANTIZED, sl::DataFormat::NHWC, sl::QuantizationInfo())); } VisitInferred(expr); return tensor_table_; } void InferTensorsVisitor::InferCall(const CallNode* cn) { EthosnError err; Call call = GetRef(cn); // Determine call -> NPU mapping if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) { ConvolutionParams params; err += EthosnAPI::QnnConv2d(cn->op.as()->body, ¶ms); tensor_table_[cn->args[0]] = {params.activation_info}; } else if (IsEthosnFunc(call, "ethos-n.qnn_fc")) { FullyConnectedParams params; err += EthosnAPI::QnnFullyConnected(cn->op.as()->body, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnOp(call, "nn.max_pool2d")) { MaxPool2DParams params; params.input_info = GetTensorInfo(tensor_table_, call); err += EthosnAPI::MaxPool2D(call, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnFunc(call, "ethos-n.qnn_avg_pool2d")) { AvgPool2DParams params; params.input_info = GetTensorInfo(tensor_table_, call); err += EthosnAPI::AvgPool2D(cn->op.as()->body, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnOp(call, "reshape")) { ReshapeParams params; params.input_info = GetTensorInfo(tensor_table_, call); err += EthosnAPI::Reshape(call, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnOp(call, "qnn.add")) { AdditionParams params; err += EthosnAPI::Addition(call, ¶ms); tensor_table_[cn->args[0]] = {params.lhs_info}; tensor_table_[cn->args[1]] = {params.rhs_info}; } else if (IsEthosnFunc(call, "ethos-n.qnn_sigmoid")) { SigmoidParams params; err += EthosnAPI::Sigmoid(cn->op.as()->body, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnOp(call, "qnn.concatenate")) { ConcatenateParams params; err = EthosnAPI::Concatenate(call, ¶ms); tensor_table_[cn->args[0]] = params.input_infos; } else if (IsEthosnOp(call, "split")) { SplitParams params; params.input_info = GetTensorInfo(tensor_table_, call); err = EthosnAPI::Split(call, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnOp(call, "nn.depth_to_space")) { DepthToSpaceParams params; params.input_info = GetTensorInfo(tensor_table_, call); err += EthosnAPI::DepthToSpace(call, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else if (IsEthosnOp(call, "clip")) { ReluParams params; params.input_info = GetTensorInfo(tensor_table_, call); err = EthosnAPI::Relu(call, ¶ms); tensor_table_[cn->args[0]] = {params.input_info}; } else { err = EthosnError("unknown operator"); } if (err) { ReportFatalError(call, err); } } // This will only visit an expression if the expression's tensor info // has already been entirely inferred. // An example where this is important is a tuple node where each // get item node will only infer one field of the tuple's expression info. // We don't want to traverse the tuple until all of its fields have been inferred. void InferTensorsVisitor::VisitInferred(const Expr& expr) { if (tensor_table_.find(expr) != tensor_table_.end()) { for (const auto& tensor_info : tensor_table_[expr]) { if (tensor_info == sl::TensorInfo()) return; } VisitExpr(expr); } } void InferTensorsVisitor::VisitExpr_(const CallNode* cn) { InferCall(cn); // Pre-order visitor for (const auto& arg : cn->args) { VisitInferred(arg); } } void InferTensorsVisitor::VisitExpr_(const TupleNode* tn) { auto tuple = GetRef(tn); ICHECK(tensor_table_.find(tuple) != tensor_table_.end()); for (size_t i = 0; i < tn->fields.size(); i++) { tensor_table_[tn->fields[i]] = {tensor_table_[tuple][i]}; } // Pre-order visitor for (const auto& field : tn->fields) { VisitExpr(field); } } void InferTensorsVisitor::VisitExpr_(const TupleGetItemNode* tgn) { // Don't assume it must be targeting a TupleNode // Vars and calls can still have TupleType auto tg = GetRef(tgn); ICHECK(tensor_table_.find(tg) != tensor_table_.end()); auto tuple = tg->tuple; auto type = tuple->checked_type().as(); int index = tg->index; // Resize the tensor infos to the tuple size if not already done if (tensor_table_.find(tuple) == tensor_table_.end()) { tensor_table_[tuple].resize(type->fields.size()); } tensor_table_[tuple][index] = tensor_table_[tg][0]; // Pre-order visitor VisitInferred(tuple); } sl::TensorsAndId MakeOps(const sl::TensorAndId& op) { sl::TensorsAndId ops; ops.tensors = {op.tensor}; ops.operationId = op.operationId; return ops; } String MakeVariant(auto configuration) { String variant = configuration.value()->variant; // Transform variant string to lowercase for comparison std::string variant_string = variant.c_str(); std::transform(variant_string.begin(), variant_string.end(), variant_string.begin(), ::tolower); std::string variant_n78 = "ethos-n78"; if (variant_string == variant_n78) { String tops = configuration.value()->tops; String ple_ratio = configuration.value()->ple_ratio; variant = "Ethos-N78_" + tops + "TOPS_" + ple_ratio + "PLE_RATIO"; } return variant; } NetworkWithIDs ConstructNetworkVisitor::Construct(const Function& func) { // Initialise everything auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relay.ext.ethos-n.options"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } NetworkWithIDs network_with_ids; network_ = sl::CreateNetwork( sl::GetFwAndHwCapabilities(sl::EthosNVariantFromString(MakeVariant(cfg).c_str()), static_cast(std::stoul(cfg.value()->sram_size)))); network_with_ids.network = network_; operand_table_.clear(); // Infer tensor information tensor_table_ = InferTensors(this->mod_, this->var_, func->body); // Add the inputs in the order they appear in the parameters unsigned int idx = 0; for (const auto& param : func->params) { for (const auto& tensor_info : tensor_table_[param]) { auto tensor_and_id = AddInput(network_, tensor_info); operand_table_[param].push_back(tensor_and_id.tensor); id_table_[param].push_back(std::make_pair(tensor_and_id.operationId, 0)); network_with_ids.input_ids[tensor_and_id.operationId] = idx++; } } // Add the function body VisitExpr(func->body); // Add the outputs idx = 0; for (const auto& layer : operand_table_[func->body]) { AddOutput(network_, *layer); network_with_ids.output_ids[id_table_[func->body][idx]] = idx; idx++; } return network_with_ids; } sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) { EthosnError err; Call call = GetRef(cn); sl::TensorAndId tensor; sl::TensorsAndId tensors; // Determine call -> NPU mapping if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) { if ((err = MakeConvolutionLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnFunc(call, "ethos-n.qnn_fc")) { if ((err = MakeFullyConnectedLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "nn.max_pool2d")) { if ((err = MakeMaxPool2DLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnFunc(call, "ethos-n.qnn_avg_pool2d")) { if ((err = MakeAvgPool2DLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "reshape")) { if ((err = MakeReshapeLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "qnn.add")) { if ((err = MakeAdditionLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnFunc(call, "ethos-n.qnn_sigmoid")) { if ((err = MakeSigmoidLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "qnn.concatenate")) { if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "split")) { if ((err = MakeSplitLayer(call, &tensors))) ReportFatalError(call, err); return tensors; } else if (IsEthosnOp(call, "nn.depth_to_space")) { if ((err = MakeDepthToSpaceLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else if (IsEthosnOp(call, "clip")) { if ((err = MakeReluLayer(call, &tensor))) ReportFatalError(call, err); return MakeOps(tensor); } else { ReportFatalError(call, EthosnError("unknown operator")); return {}; } } void ConstructNetworkVisitor::VisitExpr_(const CallNode* cn) { auto operand = HandleCall(cn); operand_table_[GetRef(cn)] = operand.tensors; for (size_t i = 0; i < operand.tensors.size(); i++) { id_table_[GetRef(cn)].push_back(std::make_pair(operand.operationId, i)); } } void ConstructNetworkVisitor::VisitExpr_(const TupleNode* op) { Tuple tuple = GetRef(op); for (const auto& arg : tuple->fields) { // The fields in a tuple should not themselves be tuples // Nested tuples are not supported if (operand_table_[arg].size() == 1) { operand_table_[tuple].push_back(operand_table_[arg][0]); id_table_[tuple].push_back(id_table_[arg][0]); } else { operand_table_[tuple].push_back(nullptr); id_table_[tuple].push_back(std::make_pair(0, 0)); } } } void ConstructNetworkVisitor::VisitExpr_(const TupleGetItemNode* tg) { Expr tuple = tg->tuple; operand_table_[GetRef(tg)] = {operand_table_[tuple][tg->index]}; id_table_[GetRef(tg)] = {id_table_[tuple][tg->index]}; } void ConstructNetworkVisitor::VisitLeaf(const Expr& expr) { // Don't traverse into functions, they're not supported if (!expr->IsInstance()) MixedModeVisitor::VisitLeaf(expr); } EthosnError ConstructNetworkVisitor::MakeConvolutionLayer(const Call& call, sl::TensorAndId* out) { ConvolutionParams params; if (auto err = EthosnAPI::QnnConv2d(call->op.as()->body, ¶ms)) { return err; } auto activation = operand_table_[call->args[0]][0]; auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor; auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor; try { if (params.is_depthwise) { *out = AddDepthwiseConvolution(network_, *activation, *bias, *weights, params.conv_info); } else { *out = AddConvolution(network_, *activation, *bias, *weights, params.conv_info); } } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeFullyConnectedLayer(const Call& call, sl::TensorAndId* out) { FullyConnectedParams params; if (auto err = EthosnAPI::QnnFullyConnected(call->op.as()->body, ¶ms)) { return err; } auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor; auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor; try { auto input = AddReshape(network_, *operand_table_[call->args[0]][0], params.input_info.m_Dimensions) .tensor; *out = AddFullyConnected(network_, *input, *bias, *weights, params.fc_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeMaxPool2DLayer(const Call& call, sl::TensorAndId* out) { MaxPool2DParams params; params.input_info = GetTensorInfo(tensor_table_, call); if (auto err = EthosnAPI::MaxPool2D(call, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *out = AddPooling(network_, *input, params.pool_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeAvgPool2DLayer(const Call& call, sl::TensorAndId* out) { AvgPool2DParams params; params.input_info = GetTensorInfo(tensor_table_, call); if (auto err = EthosnAPI::AvgPool2D(call->op.as()->body, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *out = AddPooling(network_, *input, params.pool_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeReshapeLayer(const Call& call, sl::TensorAndId* out) { ReshapeParams params; params.input_info = GetTensorInfo(tensor_table_, call); if (auto err = EthosnAPI::Reshape(call, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *out = AddReshape(network_, *input, params.new_shape); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeAdditionLayer(const Call& call, sl::TensorAndId* out) { AdditionParams params; if (auto err = EthosnAPI::Addition(call, ¶ms)) { return err; } auto lhs = operand_table_[call->args[0]][0]; auto rhs = operand_table_[call->args[1]][0]; try { *out = AddAddition(network_, *lhs, *rhs, params.output_quantization_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeSigmoidLayer(const Call& call, sl::TensorAndId* out) { SigmoidParams params; if (auto err = EthosnAPI::Sigmoid(call->op.as()->body, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *out = AddSigmoid(network_, *input); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call, sl::TensorAndId* out) { ConcatenateParams params; if (auto err = EthosnAPI::Concatenate(call, ¶ms)) { return err; } std::vector layers; auto ops = operand_table_[call->args[0]]; for (const auto& op : ops) { layers.emplace_back(op.get()); } try { *out = AddConcatenation(network_, layers, params.concat_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeSplitLayer(const Call& call, sl::TensorsAndId* outs) { SplitParams params; params.input_info = GetTensorInfo(tensor_table_, call); if (auto err = EthosnAPI::Split(call, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *outs = AddSplit(network_, *input, params.split_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeDepthToSpaceLayer(const Call& call, sl::TensorAndId* out) { DepthToSpaceParams params; params.input_info = GetTensorInfo(tensor_table_, call); if (auto err = EthosnAPI::DepthToSpace(call, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *out = AddDepthToSpace(network_, *input, params.depth_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } EthosnError ConstructNetworkVisitor::MakeReluLayer(const Call& call, sl::TensorAndId* out) { ReluParams params; params.input_info = GetTensorInfo(tensor_table_, call); if (auto err = EthosnAPI::Relu(call, ¶ms)) { return err; } auto input = operand_table_[call->args[0]][0]; try { *out = AddRelu(network_, *input, params.relu_info); } catch (const sl::NotSupportedException& e) { return EthosnError(e.what()); } return EthosnError(); } runtime::Module EthosnCompiler::CreateRuntimeModule(const ObjectRef& ref) { std::vector cmms; if (ref->IsInstance()) { IRModule mod; Function func = Downcast(ref); auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "Failed to retrieved external symbol."; GlobalVar gvar = GlobalVar(name_node.value()); mod->Add(gvar, func); Function mod_func = Downcast(mod->functions.at(gvar)); cmms.emplace_back(CompileEthosnFunc(mod, gvar, mod_func)); } else { LOG(FATAL) << "The input ref is expected to be a Relay function"; } auto n = make_object(&cmms); return runtime::Module(n); } runtime::ethosn::OrderedCompiledNetwork EthosnCompiler::CompileEthosnFunc(const IRModule& mod, const GlobalVar& gvar, const Function& func) { // Construct the network auto network_with_ids = ConstructNetwork(mod, gvar, func); // Now set the required build flags sl::CompilationOptions options = CreateOptions(); // Finally compile the network std::vector> compiled_networks = sl::Compile(*network_with_ids.network, options); ICHECK_GE(compiled_networks.size(), 1) << "Ethos-N compiler failed to compile network"; auto compiled_network = std::move(compiled_networks[0]); // Determine the order that the inputs/outputs are in and how that corresponds to the // order that the TVM runtime will expect them in auto input_output_order = GetInputOutputOrder(network_with_ids, compiled_network); // Use the order information to create an 'ordered' network with includes how to map // the inputs/outputs from the TVM runtime to the inputs/outputs of the compiled network runtime::ethosn::OrderedCompiledNetwork ordered_network; ordered_network.name = gvar->name_hint; ordered_network.compiled_cmm = std::move(compiled_network); ordered_network.inputs = input_output_order.first; ordered_network.outputs = input_output_order.second; return ordered_network; } sl::CompilationOptions EthosnCompiler::CreateOptions() { auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relay.ext.ethos-n.options"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } sl::CompilationOptions options; options.m_Strategy0 = cfg.value()->strategy0; options.m_Strategy1 = cfg.value()->strategy1; options.m_Strategy3 = cfg.value()->strategy3; options.m_Strategy4 = cfg.value()->strategy4; options.m_Strategy6 = cfg.value()->strategy6; options.m_Strategy7 = cfg.value()->strategy7; options.m_DebugInfo.m_DumpRam = cfg.value()->dump_ram; options.m_DebugInfo.m_InitialSramDump = cfg.value()->initial_sram_dump; options.m_BlockConfig16x16 = cfg.value()->block_config_16x16; options.m_BlockConfig32x8 = cfg.value()->block_config_32x8; options.m_BlockConfig8x32 = cfg.value()->block_config_8x32; options.m_BlockConfig8x8 = cfg.value()->block_config_8x8; options.m_EnableIntermediateCompression = cfg.value()->enable_intermediate_compression; options.m_DisableWinograd = cfg.value()->disable_winograd; options.m_DebugInfo.m_DebugDir = cfg.value()->debug_dir; options.m_CompilerAlgorithm = sl::EthosNCompilerAlgorithmFromString(cfg.value()->compiler_algorithm.c_str()); return options; } std::pair, std::vector> EthosnCompiler::GetInputOutputOrder( NetworkWithIDs network, const std::unique_ptr& compiled_network) { std::vector input_infos = compiled_network->GetInputBufferInfos(); std::vector output_infos = compiled_network->GetOutputBufferInfos(); std::vector input_order; std::vector output_order; // Find the order of the inputs in the compiled network for (const auto& input_info : input_infos) { input_order.push_back(network.input_ids[input_info.m_SourceOperationId]); } // Find the order of the outputs in the compiled network for (const auto& output_info : output_infos) { auto output_id = std::make_pair(output_info.m_SourceOperationId, output_info.m_SourceOperationOutputIndex); output_order.push_back(network.output_ids[output_id]); } return std::make_pair(input_order, output_order); } std::unique_ptr EthosnCompiler::m_Queries; EthosnError EthosnCompiler::SupportedSetup() { if (m_Queries == nullptr) { auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relay.ext.ethos-n.options").defined() ? ctx->GetConfig("relay.ext.ethos-n.options") : AttrsWithDefaultValues(); m_Queries = std::make_unique(sl::GetFwAndHwCapabilities( sl::EthosNVariantFromString(MakeVariant(cfg).c_str()), std::stoul(cfg.value()->sram_size))); if (m_Queries == nullptr) { return EthosnError("Could not initialise Arm(R) Ethos(TM)-N compiler isSupported"); } } return EthosnError(); } TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); err += EthosnCompiler::SupportedSetup(); if (params.is_depthwise) { *rv = !err && EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported( params.bias_info, params.weights_info, params.conv_info, params.activation_info); } else { *rv = !err && EthosnCompiler::GetSupported()->IsConvolutionSupported( params.bias_info, params.weights_info, params.conv_info, params.activation_info); } }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsFullyConnectedSupported( params.bias_info, params.weights_info, params.fc_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsReshapeSupported(params.new_shape, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsAdditionSupported( params.lhs_info, params.rhs_info, params.output_quantization_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported(params.input_infos, params.concat_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsSplitSupported(params.input_info, params.split_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported(params.input_info, params.depth_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); err += EthosnCompiler::SupportedSetup(); *rv = !err && EthosnCompiler::GetSupported()->IsReluSupported(params.relu_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { #if defined ETHOSN_HW *rv = true; #else *rv = false; #endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.api.version").set_body_typed([]() -> int { return _ETHOSN_API_VERSION_; }); } // namespace ethosn } // namespace contrib } // namespace relay } // namespace tvm