/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file src/relay/backend/contrib/verilator/codegen.cc * \brief Implementation of Verilator codegen APIs. */ #include <tvm/relay/attrs/nn.h> #include <tvm/relay/expr_functor.h> #include <tvm/relay/transform.h> #include <tvm/relay/type.h> #include <tvm/runtime/module.h> #include <tvm/runtime/registry.h> #include <fstream> #include <numeric> #include <sstream> #include "../../../../runtime/contrib/json/json_node.h" #include "../../../../runtime/contrib/verilator/verilator_runtime.h" #include "../../utils.h" #include "../codegen_json/codegen_json.h" namespace tvm { namespace relay { namespace contrib { using namespace backend; /*! \brief Verilator JSON serializer */ class VerilatorJSONSerializer : public backend::contrib::JSONSerializer { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; public: VerilatorJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {} std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override { Expr expr = GetRef<Expr>(cn); std::string name; const CallNode* call = cn; if (const auto* op_node = cn->op.as<OpNode>()) { name = op_node->name; } else { LOG(FATAL) << "Verilator JSON runtime does not support calls to " << cn->op->GetTypeKey(); } std::vector<JSONGraphNodeEntry> inputs; for (const auto& arg : cn->args) { auto res = VisitExpr(arg); inputs.insert(inputs.end(), res.begin(), res.end()); } auto node = std::make_shared<JSONGraphNode>(name, /* name_ */ "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); SetCallNodeAttribute(node, call); return AddNode(node, GetRef<Expr>(cn)); } }; /*! \brief Attributes to store options for Verilator */ struct VerilatorOptionsNode : public tvm::AttrsNode<VerilatorOptionsNode> { String lib_path; int reset_cycles; bool profiler_enable; int profiler_cycle_counter_id; TVM_DECLARE_ATTRS(VerilatorOptionsNode, "ext.attrs.VerilatorOptionsNode") { TVM_ATTR_FIELD(lib_path).describe("the design library path").set_default("libverilator.so"); TVM_ATTR_FIELD(reset_cycles).describe("the number of reset cycles").set_default(1); TVM_ATTR_FIELD(profiler_enable).describe("enable profiler").set_default(false); TVM_ATTR_FIELD(profiler_cycle_counter_id).describe("profiler cycle counter id").set_default(0); } }; class VerilatorOptions : public Attrs { public: TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorOptions, Attrs, VerilatorOptionsNode); }; TVM_REGISTER_NODE_TYPE(VerilatorOptionsNode); TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorOptions); /*! * \brief The Verilator codegen tool. It takes a Relay expression/module and * compile it into a Verilator runtime module. */ runtime::Module VerilatorBackend(const ObjectRef& ref) { VLOG(0) << "compiling for verilator runtime"; CHECK(ref->IsInstance<FunctionNode>()); auto func = Downcast<Function>(ref); auto func_name = GetExtSymbol(func); VerilatorJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); auto params = serializer.GetParams(); // Create runtime object auto n = make_object<runtime::contrib::VerilatorRuntime>(func_name, graph_json, params); // Get Verilator compiler options auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig<VerilatorOptions>("relay.ext.verilator.options"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues<VerilatorOptions>(); } n->SetLibrary(cfg.value()->lib_path); n->SetResetCycles(cfg.value()->reset_cycles); if (cfg.value()->profiler_enable) { n->EnableProfiler(); n->SetProfilerCycleCounterId(cfg.value()->profiler_cycle_counter_id); } return runtime::Module(n); } TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorBackend); } // namespace contrib } // namespace relay } // namespace tvm