/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API. */ #include #include #include #include #include #include #include #include #include #include #include #include #include "arg_binder.h" #include "ir_utils.h" namespace tvm { namespace tir { PrimFunc MakeUnpackedAPI(PrimFunc&& func) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute"; auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute"; auto* func_ptr = func.CopyOnWrite(); // Setup device context int target_device_type = target.value()->kind->device_type; Integer device_type(target_device_type); Integer device_id(0); PrimExpr node = StringImm("default"); const Stmt nop = Evaluate(0); std::vector device_init; // Create arg to buffer binder std::unordered_map vmap; ArgBinder binder(&vmap); // Collect variables and buffers to map between Array args; std::vector> var_def; bool buffer_map_found = false; for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; auto it = func_ptr->buffer_map.find(param); if (it != func_ptr->buffer_map.end()) { args.push_back((*it).second->data); buffer_map_found = true; } else { args.push_back(param); } } if (buffer_map_found) { device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop)); device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } func_ptr->body = MergeNest({device_init, binder.init_nest(), binder.asserts()}, func_ptr->body); func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function. return std::move(func); } namespace transform { Pass MakeUnpackedAPI() { auto pass_func = [](IRModule m, PassContext ctx) { IRModuleNode* mptr = m.CopyOnWrite(); std::vector> updates; for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakeUnpackedAPI(std::move(func)); updates.push_back({kv.first, updated_func}); } } } for (const auto& pair : updates) { mptr->AddUnchecked(pair.first, pair.second); } return m; }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); } TVM_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); } // namespace transform } // namespace tir } // namespace tvm