/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file env_func.cc */ #include #include #include namespace tvm { using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "EnvFunc(" << op->name << ")"; }); ObjectPtr CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); ICHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; ObjectPtr n = make_object(); n->func = *f; n->name = name; return n; } EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) { EnvFunc env = args[0]; ICHECK_GE(args.size(), 1); env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv); }); TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) { return n->func; }); TVM_REGISTER_NODE_TYPE(EnvFuncNode) .set_creator(CreateEnvNode) .set_repr_bytes([](const Object* n) -> std::string { return static_cast(n)->name; }); } // namespace tvm