/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file src/relay/analysis/get_calibration_data.cc * * \brief To get the calibration data, we need to perform two * steps. First, we need to prepare the module that generates * the tensor values (GetCalibrateModule). Second, we need to * generate the mapping between the values and the functions * (GetCalibrateOutputMap). */ #include #include #include namespace tvm { namespace relay { /*! * \brief This function returns a module that will be used by * the relay graph executor for collecting the calibration data. * To do that, we first make all inputs and outputs of each * function into the final output (i.e., the final output is a * tuple of tensors). Then, we change the compiler attribute of * each function. Finally, we mark all function to be inlined. */ class Collector : public ExprRewriter { public: explicit Collector(const IRModule& module) : module_(module) {} Expr Rewrite_(const CallNode* call, const Expr& post) final { // check if the function implementation is available // intrinsic functions are excluded for now if (call->op->IsInstance()) { auto var = Downcast(call->op); ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined"; // we only handle functions with Compiler attribute set auto func = Downcast(module_->Lookup(var)); if (func->GetAttr(attr::kCompiler)) { // collect all the inputs and outputs for (const auto& it : call->args) new_outputs_.push_back(it); new_outputs_.push_back(post); } } return post; } Array GetNewOutputs() { return new_outputs_; } private: const IRModule& module_; Array new_outputs_; }; Expr FlattenOutputTuple(const Array& exprs) { Array fields; for (const auto& it : exprs) { ICHECK(it->checked_type_.defined()); if (auto* tn = it->checked_type_.as()) { // TODO(seanlatias): for now input argument cannot be a tuple ICHECK(it->IsInstance()); for (size_t i = 0; i < tn->fields.size(); i++) { fields.push_back(TupleGetItem(it, i)); } } else { fields.push_back(it); } } return Tuple(fields); } IRModule GetCalibrateModule(IRModule module) { auto glob_funcs = module->functions; // module is mutable, hence, we make a copy of it. module.CopyOnWrite(); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); // we only collect the outputs for main function if (pair.first->name_hint == "main") { Collector collector(module); PostOrderRewrite(func->body, &collector); auto new_outputs = collector.GetNewOutputs(); Expr tuple = FlattenOutputTuple(new_outputs); func = Function(func->params, tuple, tuple->checked_type_, func->type_params, func->attrs); module->Update(pair.first, func); } } } // reset the attribute of functions for running graph executor for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); if (func->GetAttr(attr::kCompiler)) { // we need to inline the functions in order to run grpah runtime func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1)); // reset the compiler attribute to null for llvm execution func = WithAttr(std::move(func), attr::kCompiler, NullValue()); module->Update(pair.first, func); } } } return module; } /*! * \brief This function generates the output mapping between * the calibration data and each function. The key is a * GlobalVar that corresponds to each function and the value * is an array of integers. The size of the array is always * three. The first value is the offset the points to the start. * The second value is the number of inputs. The third value * is the number of outputs. */ class OutputMapper : public ExprRewriter { public: OutputMapper(Map>* output_map, const IRModule& module, size_t* offset) : output_map_(output_map), module_(module), offset_(offset) {} Expr Rewrite_(const CallNode* call, const Expr& post) final { if (call->op->IsInstance()) { auto var = Downcast(call->op); ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined"; ICHECK_EQ(output_map_->count(var), 0) << "Repeated function call " << var << " is not supported."; auto func = Downcast(module_->Lookup(var)); // we only handle functions with Compiler attribute set if (func->GetAttr(attr::kCompiler)) { Array info; // the first value is the offset info.push_back(Integer(*offset_)); // the second value is the number of inputs info.push_back(Integer(call->args.size())); // the third value is the number of outputs // we need to check if the output is a tuple size_t out_size = 1; if (auto* tn = func->body.as()) { info.push_back(Integer(tn->fields.size())); out_size = tn->fields.size(); } else { info.push_back(Integer(1)); } output_map_->Set(var, info); // calculate the offset for the next function *offset_ = *offset_ + call->args.size() + out_size; } } return post; } private: Map>* output_map_; const IRModule& module_; size_t* offset_; }; Map> GetCalibrateOutputMap(const IRModule& module) { Map> output_map; size_t offset = 0; auto glob_funcs = module->functions; for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { if (pair.first->name_hint == "main") { OutputMapper output_mapper(&output_map, module, &offset); auto func = GetRef(fn); PostOrderRewrite(func->body, &output_mapper); } } } return output_map; } TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_module").set_body_typed([](IRModule mod) { return GetCalibrateModule(mod); }); TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_output_map") .set_body_typed([](const IRModule& mod) { return GetCalibrateOutputMap(mod); }); } // namespace relay } // namespace tvm