/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file eta_expand.cc * * \brief Add an abstraction over constructors and/or global variables bound to a function. * */ #include #include #include #include namespace tvm { namespace relay { namespace eta_expand { /*! * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality */ class TypeVarReplacer : public TypeMutator { public: TypeVarReplacer() : replace_map_({}) {} Type VisitType_(const TypeVarNode* type_var_node) final { const auto type_var = GetRef(type_var_node); if (replace_map_.find(type_var) == replace_map_.end()) { replace_map_[type_var] = TypeVar("A", Kind::kType); } return replace_map_[type_var]; } private: /*! \brief variable replacement map to remap old type vars to fresh ones */ std::unordered_map replace_map_; }; /*! * \brief mutator to perform eta expansion on all functions in a module */ class EtaExpander : public ExprMutator { public: explicit EtaExpander(const IRModule& mod, bool expand_constructor, bool expand_global_var) : mod_(mod), type_var_replacer_(TypeVarReplacer()), expand_constructor_(expand_constructor), expand_global_var_(expand_global_var) { ICHECK(expand_constructor || expand_global_var) << "must expand at least one language feature"; } IRModule Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { const BaseFunc base_func = mod_->Lookup(global_var); if (auto* n = base_func.as()) { const Function new_func = Downcast(VisitExpr(GetRef(n))); mod_->Update(global_var, new_func); } } return mod_; } Expr VisitExpr_(const CallNode* call) final { // we don't need to expand constructors when they are being called, so we // prevent them being visited here Expr new_op = call->op; if (!call->op.as()) { new_op = VisitExpr(new_op); } tvm::Array new_args; for (const auto& arg : call->args) { new_args.push_back(VisitExpr(arg)); } return Call(new_op, new_args, call->attrs, call->type_args); } Expr VisitExpr_(const ConstructorNode* cons_node) final { Constructor cons = GetRef(cons_node); if (!expand_constructor_) { return std::move(cons); } // NOTE: we only reach this case if the constructor is not being applied to any arguments tvm::Array params; for (const auto& type : cons->inputs) { Type param_type = type_var_replacer_.VisitType(type); params.push_back(Var("eta_expand_param", param_type)); } tvm::Array type_params; TypeData adt_def = mod_->LookupTypeDef(cons->belong_to); for (const auto& type_var : adt_def->type_vars) { type_params.push_back(type_var_replacer_.VisitType(type_var)); } Expr body = Call(cons, params, Attrs()); Type ret_type = TypeCall(cons->belong_to, type_params); return Function(Downcast>(params), body, ret_type, Downcast>(type_params)); } Expr VisitExpr_(const GlobalVarNode* gvar_node) final { GlobalVar gvar = GetRef(gvar_node); if (!expand_global_var_) { return std::move(gvar); } const auto base_func = mod_->Lookup(gvar); if (auto* ptr = base_func.as()) { // handle relay function, skip external functions. auto func = GetRef(ptr); tvm::Array params; tvm::Array args; for (size_t i = 0; i < func->params.size(); ++i) { auto var = Var("eta_expand_param", func->params[i]->type_annotation); params.push_back(var); args.push_back(var); } return Function(args, Call(gvar, params), func->ret_type, func->type_params); } else { return std::move(gvar); } } private: /*! \brief reference to module being expanded */ const IRModule mod_; /*! \brief type variable replacer */ TypeVarReplacer type_var_replacer_; /*! \brief whether to expand constructor nodes */ bool expand_constructor_; /*! \brief whether to expand global variable nodes */ bool expand_global_var_; }; } // namespace eta_expand namespace transform { Pass EtaExpand(bool expand_constructor, bool expand_global_var) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); }; return CreateModulePass(pass_func, 1, "EtaExpand", {}); } TVM_REGISTER_GLOBAL("relay._transform.EtaExpand").set_body_typed(EtaExpand); } // namespace transform } // namespace relay } // namespace tvm