/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file src/relay/transforms/dead_code.cc * \brief Elides or inlines let-bindings. * * TODO(mbs): Track dead writes into references. */ #include #include #include #include #include "../op/call/call.h" namespace tvm { namespace relay { namespace { /*! \brief Maximum depth of calls to analyize. */ constexpr int kMaxCallDepth = 25; /*! * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression * can safely be elided (if its result is never used), or inlined (which may change the * number of times and program order for the evaluation.) */ struct Purity { /*! * \brief True if evaling the sub-expression itself is pure. */ bool pure_eval; /*! * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling * a call to the the sub-expression is pure. See [RULE A] below. */ bool pure_call; }; /*! * \brief Visits all the global functions in a module and records the purity of every let-bound * value. * * (See also inline.cc for function inlining.) * * Generally we track whether evaluation of a sub-expression is definitely pure. However for * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f: * - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is * also call-pure. * - [RULE B] Higher-order function arguments are assumed call impure. * - [RULE C] We assume functions extracted from tuples are call impure. * - [RULE D] We assume functions extracted from references are call impure. * - [RULE E] We assume functions extracted from ADTs are call impure. * - [RULE F] We assume all external Functions and PrimFuncs are call impure. */ class PurityVisitor : ExprFunctor { public: explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {} /*! \brief Visit all the functions in the module. */ void VisitModule() { VLOG_CONTEXT << "PurityVisitor"; // It is safe to visit the global functions in any order. Recursive global functions are // allowed. for (const auto& kv : mod_->functions) { if (const auto* function_node = kv.second.as()) { if (function_node->HasNonzeroAttr(attr::kPrimitive) || function_node->GetAttr(attr::kExternalSymbol)) { // Ignore primitive and external functions. continue; } // Everything of interest will be recorded in the purity maps so we ignore the result. (void)VisitGlobalFunction(kv.first, GetRef(function_node)); } } } /*! * \brief Returns a map from every let-bound variable to whether its let-bound value is * definitely pure. */ std::unordered_map GetPurityMap() const { std::unordered_map result; for (const auto& kv : var_to_purity_) { result.emplace(kv.first, kv.second.pure_eval); } return result; } private: Purity VisitExpr(const Expr& expr) final { auto it = memo_.find(expr.get()); if (it != this->memo_.end()) { return it->second; } else { Purity result = ExprFunctor::VisitExpr(expr); memo_[expr.get()] = result; return result; } } Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; } Purity VisitExpr_(const ConstructorNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; } Purity VisitExpr_(const OpNode* op_node) final { // Primitive operators are pure unless marked as 'stateful'. static OpAttrMap attr_map = Op::GetAttrMap("TOpIsStateful"); bool is_stateful = attr_map.count(GetRef(op_node)) && attr_map[GetRef(op_node)]; return {/*pure_eval=*/true, /*pure_call=*/!is_stateful}; } Purity VisitExpr_(const GlobalVarNode* global_var_node) final { auto global_var = GetRef(global_var_node); auto func = mod_->Lookup(global_var); if (const auto* function_node = func.as()) { if (!function_node->GetAttr(attr::kExternalSymbol)) { return VisitGlobalFunction(global_var, GetRef(function_node)); } } // Assume externals and PrimFuncs are call-impure [RULE F]. // (If they are pure then we should have dealt with them before lowering.) return {/*pure_eval==*/true, /*pure_call=*/false}; } Purity VisitExpr_(const VarNode* var_node) final { // The var is bound to a value, but if that value is a function we need to propagate the // function body's purity. ICHECK(var_to_purity_.count(var_node)) << PrettyPrint(GetRef(var_node)); return {/*pure_eval=*/true, /*pure_call=*/var_to_purity_[var_node].pure_call}; } Purity VisitExpr_(const FunctionNode* function_node) final { for (const auto& param : function_node->params) { // Any higher-order parameters are assumed to be call-impure [RULE B] var_to_purity_[param.get()] = {/*pure_eval=*/true, /*pure_call=*/IsFirstOrder(param)}; } Purity body_purity = VisitExpr(function_node->body); // The function itself is a value and thus pure. If the function returns // a function we'll fold its purity in here [RULE A] return {/*pure_eval=*/true, /*pure_call=*/body_purity.pure_eval && body_purity.pure_call}; } Purity VisitExpr_(const LetNode* let_node) final { Expr expr = GetRef(let_node); bool all_values_pure_eval = true; while (const auto* inner_let_node = expr.as()) { // In case the value is a recursive function assume the let-bound variable is call-pure. var_to_purity_[inner_let_node->var.get()] = {/*pure_eval=*/true, /*pure_call=*/true}; Purity value_purity = VisitExpr(inner_let_node->value); // Now revise the variable to it's true purity. var_to_purity_[inner_let_node->var.get()] = value_purity; VLOG(2) << (value_purity.pure_eval ? "pure" : "impure") << " expression:" << std::endl << PrettyPrint(inner_let_node->value) << std::endl << "let-bound to variable:" << std::endl << PrettyPrint(inner_let_node->var); all_values_pure_eval = all_values_pure_eval && value_purity.pure_eval; expr = inner_let_node->body; } Purity body_purity = VisitExpr(expr); return {/*pure_eval=*/all_values_pure_eval && body_purity.pure_eval, /*pure_call=*/body_purity.pure_call}; } Purity VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); if (current_call_depth_ >= kMaxCallDepth) { // Assume impure. VLOG(2) << "assuming call is impure since too deeply nested"; return {/*pure_eval=*/false, /*pure_call*/ IsFirstOrder(call)}; } ++current_call_depth_; // We can work with calls in both pre- and post-lowered form. Call vanilla_call = GetAnyCall(call_node); // Find purity for the callee and the args. Purity callee_purity = VisitExpr(vanilla_call->op); bool all_args_pure_eval = true; for (const auto& arg : vanilla_call->args) { Purity arg_purity = VisitExpr(arg); all_args_pure_eval = all_args_pure_eval && arg_purity.pure_eval; } VLOG(2) << (callee_purity.pure_call ? "pure" : "impure") << " call to:" << std::endl << PrettyPrint(vanilla_call->op); ICHECK_GT(current_call_depth_, 0); --current_call_depth_; // If the callee's result is itself a function then by [RULE A] its purity // is given by callee_purity.pure_call. return {/*pure_eval=*/all_args_pure_eval && callee_purity.pure_eval && callee_purity.pure_call, /*pure_call=*/IsFirstOrder(call) || callee_purity.pure_call}; } Purity VisitExpr_(const IfNode* if_node) final { Purity cond_purity = VisitExpr(if_node->cond); ICHECK(cond_purity.pure_call); // conditional is first-order Purity true_purity = VisitExpr(if_node->true_branch); Purity false_purity = VisitExpr(if_node->false_branch); return {/*pure_eval=*/cond_purity.pure_eval && true_purity.pure_eval && false_purity.pure_eval, /*pure_call=*/true_purity.pure_call && false_purity.pure_call}; } Purity VisitExpr_(const TupleNode* tuple_node) final { bool all_fields_pure = true; for (const auto& field : tuple_node->fields) { // The call purity of each tuple field is lost [RULE C]. Purity field_purity = VisitExpr(field); if (!field_purity.pure_eval) { all_fields_pure = false; } } return {/*pure_eval=*/all_fields_pure, /*pure_call=*/true}; } Purity VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { Purity tuple_purity = VisitExpr(tuple_get_item_node->tuple); ICHECK(tuple_purity.pure_call); // tuple is first-order // We don't track call purity through tuple fields, so if the result is a function type we // must assume it is call impure [RULE C]. return {/*pure_eval=*/tuple_purity.pure_eval, /*pure_call=*/IsFirstOrder(GetRef(tuple_get_item_node))}; } Purity VisitExpr_(const RefCreateNode*) final { // The creation of the ref itself is unobservable other than via the reads/writes into it. return {/*pure_eval=*/true, /*pure_call=*/true}; } Purity VisitExpr_(const RefWriteNode* ref_write_node) final { Purity ref_purity = VisitExpr(ref_write_node->ref); ICHECK(ref_purity.pure_call); // reference is first-order // The call purity of the written value is lost [RULE D]. // (But we must still visit to accumulate purity for any let-bindings within in.) (void)VisitExpr(ref_write_node->value); return {/*pure_eval=*/false, /*pure_call=*/true}; } Purity VisitExpr_(const RefReadNode* ref_read_node) final { Purity ref_purity = VisitExpr(ref_read_node->ref); ICHECK(ref_purity.pure_call); // reference is first-order // We don't track call purity through reference values, so if the result is a function // type we must assume it is call impure [RULE D]. return {/*pure_eval=*/false, /*pure_call=*/IsFirstOrder(GetRef(ref_read_node))}; } class PurityPatternVisitor : public PatternVisitor { public: explicit PurityPatternVisitor(PurityVisitor* outer) : outer_(outer) {} private: void VisitPattern_(const PatternVarNode* pattern_var_node) final { // We don't track call purity through ADTs, so if var is a function type we must assume // it is call impure [RULE E]. outer_->var_to_purity_[pattern_var_node->var.get()] = { /*pure_eval=*/true, /*pure_call=*/IsFirstOrder(pattern_var_node->var)}; } /*! \brief (Mutable borrow of) the outer visitor. */ PurityVisitor* outer_; }; Purity VisitExpr_(const MatchNode* match_node) final { Purity data_purity = VisitExpr(match_node->data); ICHECK(data_purity.pure_call); // ADT is first order bool all_clauses_pure_eval = true; bool all_clauses_pure_call = true; for (const auto& clause : match_node->clauses) { PurityPatternVisitor pattern_visitor(this); pattern_visitor.VisitPattern(clause->lhs); Purity rhs_purity = VisitExpr(clause->rhs); all_clauses_pure_eval = all_clauses_pure_eval && rhs_purity.pure_eval; all_clauses_pure_call = all_clauses_pure_call && rhs_purity.pure_call; } return {/*pure_eval=*/data_purity.pure_eval && all_clauses_pure_eval, /*pure_call=*/all_clauses_pure_call}; } /*! \brief Visits \p func bound to global \p var and returns it's purity. */ Purity VisitGlobalFunction(const GlobalVar& var, const Function& func) { VLOG_CONTEXT << "func " << var->name_hint; VLOG(2) << "visiting"; auto itr = global_var_to_purity_.find(var.get()); if (itr != global_var_to_purity_.end()) { // We've already visited the function body. return itr->second; } // We are entering the body of a possibly-recursive global function. Assume it's body is pure. global_var_to_purity_[var.get()] = {/*pure_eval=*/true, /*pure_call=*/true}; // Visit the global function for the first time. Purity func_purity = VisitExpr(func); // Update with the true purity. global_var_to_purity_[var.get()] = func_purity; return func_purity; } static bool IsFirstOrder(const Expr& expr) { return expr->checked_type().as() == nullptr; } /*! \brief The module we're analyzing. */ IRModule mod_; /*! * \brief Maps each let-bound and global variable to the purity of the value it is bound to. * If the variable is bound to a function then the purity of saturating that function is also * tracked. * * Note that global_var_to_purity_, and all the 'pure_call' fields, are only needed internally * during the analysis, andonly the var_to_purity_ 'pure_eval' fields are used downstream. */ std::unordered_map var_to_purity_; std::unordered_map global_var_to_purity_; /*! \brief The current call depth. We'll just assume deeply nested calls are impure rather than * spending all that time to check for sure. A deeply nested call is almost certain to be needed * anyway. */ int current_call_depth_; /*! \brief Internal map used for memoization. */ std::unordered_map memo_; }; /*! * \brief Accumulate the bound values and usage count for each let-bound variable. * * We don't attempt to track the number of calls to local functions, and instead just assume they * are called at least twice. */ class UsageVisitor : public ExprVisitor { public: /*! \brief Accumulates the expression bound to every let-bound variable. */ std::unordered_map let_bound_values_; /*! \brief Accumulates the usage count for every let-bound variable. */ std::unordered_map use_map_; explicit UsageVisitor(const std::unordered_map* var_to_purity, bool default_purity) : var_to_purity_(var_to_purity), default_purity_(default_purity) {} void VisitExpr(const Expr& expr) final { // Once we've seen 2 usages of a variable we know it can be neither elided nor inlined, // so can stop visiting again. if (++visit_counter_[expr.get()] <= 2) { ExprFunctor::VisitExpr(expr); } } void VisitExpr_(const FunctionNode* function_node) final { ++current_scope_level_; ExprVisitor::VisitExpr_(function_node); ICHECK_GT(current_scope_level_, 0); --current_scope_level_; } void VisitExpr_(const LetNode* let_node) final { Expr expr = GetRef(let_node); while (const auto* inner_let_node = expr.as()) { ++visit_counter_[inner_let_node]; let_bound_values_[inner_let_node->var.get()] = inner_let_node->value; VLOG(2) << "seen let-binding for:" << std::endl << PrettyPrint(inner_let_node->var); use_map_[inner_let_node->var.get()] = 0; scope_level_map_[inner_let_node->var.get()] = current_scope_level_; if (is_pure(inner_let_node->var.get())) { // We'll defer visiting the let-bound value until we've seen the first use of the let-bound // variable and thus know it must be evaluated. // no-op. } else { // The let-bound value is impure so must always be evaluated. Visit now. VisitExpr(inner_let_node->value); } expr = inner_let_node->body; } VisitExpr(expr); } void VisitExpr_(const VarNode* var_node) final { if (let_bound_values_.count(var_node)) { size_t& n = use_map_[var_node]; ++n; VLOG(2) << var_node->name_hint() << " = " << n; if (n == 1 && is_pure(var_node)) { // Now that we have at least one use of the let-bound var, we know the let-bound // value is necessary. VisitExpr(let_bound_values_[var_node]); } if (scope_level_map_[var_node] < current_scope_level_) { // Since the variable was bound outside of the current local function, assume the // function will be called at least twice. ++n; VLOG(2) << var_node->name_hint() << " = " << n << " (bound at level " << scope_level_map_[var_node] << " but used at level " << current_scope_level_ << ")"; } } // else: nothing to be done for function parameters or variable in match patterns. } bool is_pure(const VarNode* var_node) const { auto itr = var_to_purity_->find(var_node); return itr == var_to_purity_->end() ? default_purity_ : itr->second; } /*! \brief (Immutable borrow of) the already determined purity for every let-bound variable. */ const std::unordered_map* var_to_purity_; /*! \brief The default purity for variables which are not in the above map. */ bool default_purity_; /*! * \brief The current scope level. 0 for global functions. Incremented by one within each * let-bound local function. Necessary so we can avoid inlining an expensive let-bound computation * into a function which could be called more than once. */ int current_scope_level_ = 0; /*! \brief Accumulates the scope level for every let-bound variable. */ std::unordered_map scope_level_map_; }; /*! \brief Eliminate/inline let-bound values when sound to do so. */ class EliminatorMutator : public ExprMutator { public: EliminatorMutator(bool inline_once, const std::unordered_map* let_bound_values, const std::unordered_map* use_map, const std::unordered_map* var_to_purity, bool default_purity) : inline_once_(inline_once), let_bound_values_(let_bound_values), use_map_(use_map), var_to_purity_(var_to_purity), default_purity_(default_purity) {} private: enum Action { kElide, kInline, kNoChange }; /*! \brief What should we do with let-binding for \p var_node? */ Action ActionFor(const VarNode* var_node) { if (let_bound_values_->count(var_node) == 0) { // Not let-bound var. return kNoChange; } if (!is_pure(var_node)) { // The let-bound value is impure -- we must leave it exactly where it is. return kNoChange; } switch (use_map_->count(var_node) ? use_map_->at(var_node) : 0) { case 0: return kElide; case 1: return inline_once_ ? kInline : kNoChange; default: return kNoChange; } } Expr VisitExpr_(const VarNode* var_node) final { if (ActionFor(var_node) == kInline) { VLOG(1) << "inlining let-bound variable:" << std::endl << PrettyPrint(GetRef(var_node)); return VisitExpr(let_bound_values_->at(var_node)); } else { return GetRef(var_node); } } Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { if (ActionFor(op->var.get()) != kElide) { (void)VisitExpr(op->value); } }; auto post_visit = [this](const LetNode* op) { Expr body = VisitExpr(op->body); auto expr = GetRef(op); switch (ActionFor(op->var.get())) { case kElide: VLOG(1) << "eliding let-bound variable:" << std::endl << PrettyPrint(op->var); memo_[expr] = body; break; case kInline: // Already inlined at use-side. memo_[expr] = body; break; case kNoChange: Expr value = VisitExpr(op->value); memo_[expr] = Let(op->var, value, body); break; } }; ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } bool is_pure(const VarNode* var_node) const { auto itr = var_to_purity_->find(var_node); return itr == var_to_purity_->end() ? default_purity_ : itr->second; } bool inline_once_; const std::unordered_map* let_bound_values_; const std::unordered_map* use_map_; const std::unordered_map* var_to_purity_; bool default_purity_; }; } // namespace namespace transform { // Declared in relay/transform.h Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { // Which let bindings are pure and can be safely elided? std::unordered_map var_to_purity; if (!ignore_impurity) { VLOG(1) << "determine purity"; PurityVisitor purity_visitor(mod); purity_visitor.VisitModule(); var_to_purity = purity_visitor.GetPurityMap(); } IRModule result(/*functions=*/{}, mod->type_definitions, mod->Imports(), mod->source_map, mod->attrs); for (const auto& kv : mod->functions) { if (const auto* function_node = kv.second.as()) { auto function = GetRef(function_node); VLOG(1) << "processing " << PrettyPrint(kv.first); VLOG(2) << "count usage"; UsageVisitor usage_visitor(&var_to_purity, /*default_purity=*/ignore_impurity); usage_visitor.VisitExpr(function); // Actually eliminate/inline the let-bindings. VLOG(2) << "eliminate"; EliminatorMutator eliminator_mutator(inline_once, &usage_visitor.let_bound_values_, &usage_visitor.use_map_, &var_to_purity, /*default_purity=*/ignore_impurity); result->Add(kv.first, Downcast(eliminator_mutator.VisitExpr(function))); } else { // PrimFuncs come across unchanged. result->Add(kv.first, kv.second); } } return result; }; return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/1, "DeadCodeElimination", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); } // namespace transform } // namespace relay } // namespace tvm