/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file constant_folding.cc
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

#include "../op/memory/on_device.h"
#include "./pattern_utils.h"

namespace tvm {
namespace relay {
namespace transform {

namespace {
/*!
 * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device"
 * annotation CallNode (which serves only to associate a \p VirtualDevice with the constant and
 * has no operational effect).
 */
bool IsSimpleConstant(const Expr& expr) {
  return AsIgnoringOnDevice<ConstantNode>(expr) != nullptr;
}

/*!
 * \brief Returns whether \p expr \p IsSimpleConstant directly or is a tuple of
 * \p IsComplexConstant expressions.
 */
bool IsComplexConstant(const Expr& expr) {
  if (IsSimpleConstant(expr)) {
    return true;
  } else if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(expr)) {
    return std::all_of(tuple_node->fields.begin(), tuple_node->fields.end(), IsComplexConstant);
  } else {
    return false;
  }
}
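
// For illustration only (hedged; these Relay spellings are assumptions, not exercised by this
// file), the two predicates above classify expressions as follows:
//   3                               -- simple constant
//   on_device(3, virtual_device=E)  -- simple constant (the annotation is looked through)
//   (3, (4, 5))                     -- complex constant (tuple of constants)
//   (3, %x)                         -- not a constant (%x is a variable)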

// TODO(tvm-team) consider combining dead-code elimination with the constant folder,
// or making a more powerful partial evaluator.
class ConstantFolder : public MixedModeMutator {
 public:
  explicit ConstantFolder(IRModule module)
      : module_(std::move(module)),
        device_copy_op_(Op::Get("device_copy")),
        shape_of_op_(Op::Get("shape_of")),
        vm_shape_of_op_(Op::Get("vm.shape_of")),
        cast_op_(Op::Get("cast")),
        ndarray_size_op_(Op::Get("ndarray_size")) {}

 private:
  using ExprMutator::VisitExpr_;

  Expr VisitExpr_(const LetNode* let_node) final {
    auto pre_visit = [this](const LetNode* op) {
      // Rely on the Memoizer to cache pre-visit values.
      Expr new_value = Mutate(op->value);
      if (IsSimpleConstant(new_value)) {
        // Inline the new value (along with any on_device annotation wrapping it) at all
        // occurrences of the variable.
        //
        // We need to retain any "on_device" annotation so that downstream 'device aware'
        // passes can still retrieve the virtual device for the constant in its new position(s).
        // Eg:
        //   def @f(..., result_virtual_device=D) {
        //     let %x = on_device(... something we eval to a constant..., virtual_device=E)
        //     @f(..., %x, ...)
        //   }
        // Here the default virtual device is D, whereas the argument %x to @f is on E (and @f
        // expects that). No on_device annotation is required in the call according to the
        // convention used by the device-aware visitors.
        //
        // However once we've inlined the constant we need to insert an on_device, again to
        // respect the convention used by the device-aware visitors.
        //   def @f(..., result_virtual_device=D) {
        //     @f(..., on_device(...the constant..., virtual_device=E), ...)
        //   }
        VLOG(1) << "Replacing let-binding for " << op->var->name_hint()
                << " with constant:" << std::endl
                << PrettyPrint(new_value);
        memo_[op->var] = new_value;
      } else {
        this->Mutate(op->var);
      }
    };
    auto post_visit = [this](const LetNode* op) {
      Expr expr = GetRef<Expr>(op);
      // Rely on the Memoizer to cache pre-visit values.
      Expr new_value = this->Mutate(op->value);
      if (IsSimpleConstant(new_value)) {
        // The let-bound value has been inlined, drop the let-binding itself.
        this->memo_[expr] = Mutate(op->body);
      } else {
        Var new_var = Downcast<Var>(this->Mutate(op->var));
        Expr new_body = this->Mutate(op->body);
        if (new_var.same_as(op->var) && new_value.same_as(op->value) &&
            new_body.same_as(op->body)) {
          this->memo_[expr] = expr;
        } else {
          this->memo_[expr] = Let(new_var, new_value, new_body, op->span);
        }
      }
    };
    ExpandANormalForm(let_node, pre_visit, post_visit);
    return memo_[GetRef<Expr>(let_node)];
  }

  Expr VisitExpr_(const FunctionNode* function_node) final {
    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
      ICHECK_EQ(inside_primitive_, false);
      inside_primitive_ = true;
      auto ret = ExprMutator::VisitExpr_(function_node);
      inside_primitive_ = false;
      return ret;
    } else {
      return ExprMutator::VisitExpr_(function_node);
    }
  }

  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
    Call pre_call = GetRef<Call>(pre_call_node);
    if (inside_primitive_) {
      return std::move(pre_call);
    }
    Call post_call = Downcast<Call>(post);
    if (post_call->args.empty()) {
      // We don't constant fold functions with zero arguments. This heuristic is useful:
      // for example, it is harmful to fold ones(shape=(4, 5)).
      return std::move(pre_call);
    }

    static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");

    const auto* op_node = post_call->op.as<OpNode>();
    if (op_node == nullptr) {
      // Only evaluate primitives.
      return std::move(post_call);
    }
    Op op = GetRef<Op>(op_node);
    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
    if (op_stateful.get(op, false)) {
      // Skip stateful ops.
      return std::move(post_call);
    }
    // Try to evaluate shape_of and ndarray_size ops.
    // Use the original call rather than post_call here since it still has valid checked_type
    // fields. These operators don't care about the value of their argument anyway.
    if (Optional<Expr> opt_result = EvaluateShapeOf(pre_call)) {
      return opt_result.value();
    }
    // Use the original call rather than post_call here since it still has valid checked_type
    // fields. This operator doesn't care about the value of its argument anyway.
    if (Optional<Expr> opt_result = EvaluateNdarraySize(pre_call)) {
      return opt_result.value();
    }
    if ((fnoncomputational.count(op) && fnoncomputational[op]) || op == device_copy_op_ ||
        op == shape_of_op_ || op == vm_shape_of_op_ || op == ndarray_size_op_) {
      // We should think about potentially allowing constant evaluation over these ops too.
      return std::move(post_call);
    }
    if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) {
      // At least one non-constant argument.
      return std::move(post_call);
    }
    // During evaluation we have obviously lost all on_device annotations. However any
    // on_device wrapping this call will be left in place.
    return ConstEvaluate(post_call);
  }
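
  // Illustrative sketch (hedged; the Relay spellings are assumptions): once every argument is
  // constant, a call such as
  //   add(const(1), const(2))
  // is handed to ConstEvaluate below and replaced by const(3), while the zero-argument
  //   ones(shape=(4, 5))
  // is deliberately left unfolded by the early exit above, since materializing it would only
  // bloat the program.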

  Expr VisitExpr_(const IfNode* if_node) final {
    If new_if = Downcast<If>(ExprMutator::VisitExpr_(if_node));
    if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(new_if->cond)) {
      if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) {
        return new_if->true_branch;
      } else {
        return new_if->false_branch;
      }
    }
    return std::move(new_if);
  }

  Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node,
                const Expr& post_tuple_get_item) final {
    const auto* post_tuple_get_item_node = post_tuple_get_item.as<TupleGetItemNode>();
    if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(post_tuple_get_item_node->tuple)) {
      Expr result = tuple_node->fields[tuple_get_item_node->index];
      OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple);
      if (props.body.defined()) {
        // on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D)
        return MaybeOnDeviceWithProps(result, props);
      } else {
        return result;
      }
    }
    return post_tuple_get_item;
  }

  // Convert a runtime value to the equivalent expression.
  Expr ObjectToExpr(const ObjectRef& value) {
    if (value->IsInstance<runtime::NDArray::ContainerType>()) {
      auto nd_array = Downcast<runtime::NDArray>(value);
      return Constant(nd_array);
    } else if (const auto* val = value.as<runtime::ADTObj>()) {
      runtime::ADT adt = GetRef<runtime::ADT>(val);
      Array<Expr> fields;
      for (size_t i = 0; i < adt.size(); ++i) {
        fields.push_back(ObjectToExpr(adt[i]));
      }
      return Tuple(fields);
    } else {
      LOG(FATAL) << "Cannot handle " << value->GetTypeKey();
      return {};
    }
  }

  // Constant evaluate an expression.
  Expr ConstEvaluate(const Expr& expr) {
    VLOG_CONTEXT << "ConstEvaluate";
    VLOG(1) << "Evaluating:" << std::endl << PrettyPrint(expr);

    // We'll invoke the interpreter using the generic CPU device and target. Technically there's
    // no guarantee the results will be bitwise equal to what we'd get on the true device,
    // however to support cross-compilation we don't want to assume the true device is available.

    // Use a fresh build context in case we are already in a build context.
    // Needed for both execution and creation (due to JIT).
    With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());

    Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(),
                                    eval_cpu_dev_, eval_cpu_target_));
    VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result);
    return result;
  }
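
  // Usage sketch (hedged): ConstEvaluate expects a closed expression whose leaves are already
  // constants. For instance, given a hypothetical scalar Constant const_one,
  //   Expr folded = ConstEvaluate(Call(Op::Get("add"), {const_one, const_one}));
  // would run the interpreter on the compiler's CPU and return a Constant holding 2.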

  /*!
   * \brief Returns the constant shape result of \p call if it is of the form \p shape_of(e) and
   * \p e has a non-dynamic tensor shape. Returns null otherwise.
   */
  Optional<Expr> EvaluateShapeOf(const Call& call) {
    if (call->op != shape_of_op_ && call->op != vm_shape_of_op_) {
      return {};
    }

    VLOG(1) << "Evaluating for shape_of:" << std::endl << PrettyPrint(call);
    ICHECK_EQ(call->args.size(), 1);
    const auto* param = call->attrs.as<ShapeOfAttrs>();
    ICHECK(param != nullptr);
    Expr input = call->args[0];

    tvm::Array<IndexExpr> ishape;
    if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) {
      ishape = opt_shape.value();
    } else {
      return {};
    }

    // Get the constant shape.
    runtime::NDArray value;
    DLDataType cdtype = DataType::Int(32);
    if (ishape.empty()) {
      value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_);
    } else {
      ICHECK_NE(ishape.size(), 0);
      std::vector<int64_t> cshape = {static_cast<int64_t>(ishape.size())};
      value = runtime::NDArray::Empty(cshape, cdtype, eval_cpu_dev_);
      auto* dims = static_cast<int32_t*>(value->data);
      using ::tvm::tir::IntImmNode;
      for (size_t i = 0; i < ishape.size(); ++i) {
        if (const auto* dim = ishape[i].as<IntImmNode>()) {
          dims[i] = dim->value;
        } else {
          return {};
        }
      }
    }

    Constant shape = Downcast<Constant>(ObjectToExpr(value));

    if (shape->data.Shape().empty() && GetScalarFromConstant<int32_t>(shape) == 0) {
      auto ndarray = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_);
      shape = Constant(ndarray);
    }

    return CastValue(shape, param->dtype);
  }

  /*!
   * \brief Returns the constant NDArray size of the result of \p call if it is of the form
   * \p ndarray_size(e) and \p e has a non-dynamic tensor type. Returns null otherwise.
   */
  Optional<Expr> EvaluateNdarraySize(const Call& call) {
    if (call->op != ndarray_size_op_) {
      return {};
    }

    VLOG(1) << "Evaluating for ndarray_size:" << std::endl << PrettyPrint(call);
    ICHECK_EQ(call->args.size(), 1);
    Expr input = call->args[0];
    const auto* param = call->attrs.as<NdarraySizeAttrs>();
    ICHECK(param != nullptr);

    tvm::Array<IndexExpr> ishape;
    if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) {
      ishape = opt_shape.value();
    } else {
      return {};
    }

    // Get the constant size.
    runtime::NDArray value;
    DLDataType cdtype = DataType::Int(32);
    value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_);
    auto* data = static_cast<int32_t*>(value->data);
    if (ishape.empty()) {
      *data = 0;
    } else {
      *data = 1;
      using ::tvm::tir::IntImmNode;
      for (size_t i = 0; i < ishape.size(); ++i) {
        if (const auto* dim = ishape[i].as<IntImmNode>()) {
          *data *= dim->value;
        } else {
          return {};
        }
      }
    }

    Constant size = Downcast<Constant>(ObjectToExpr(value));
    return CastValue(size, param->dtype);
  }
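
  // Illustrative sketch (hedged; the Relay types/spellings are assumptions): given
  // %x : Tensor[(2, 3), float32],
  //   shape_of(%x)      ==>  [2, 3]  (a constant 1-d int32 tensor, then cast to attrs.dtype)
  //   ndarray_size(%x)  ==>  6       (a constant scalar, likewise cast to attrs.dtype)
  // Neither folds if any dimension of %x is dynamic, i.e. not an IntImmNode.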

  Expr CastValue(const Expr& value, DataType dtype) {
    // Cast the constant into the correct dtype.
    auto cast_attrs = make_object<CastAttrs>();
    cast_attrs->dtype = dtype;
    Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
    return ConstEvaluate(ret);
  }

  Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
    if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(input)) {
      // TODO(mbs): This is not necessary since we only ever ask for the shapes for
      // pre-rewritten expressions which will always have a checked_type.
      return const_node->tensor_type()->shape;
    } else if (input->checked_type_.defined()) {
      return input->checked_type().as<TensorTypeNode>()->shape;
    } else {
      return {};
    }
  }

  // Module
  IRModule module_;

  // The kDLCPU device assumed to be available to the compiler. Used only when evaluating
  // sub-expressions.
  Device eval_cpu_dev_{kDLCPU, /*device_id=*/0};
  // The target for the above device assumed to be available to the compiler. Used only when
  // evaluating sub-expressions.
  Target eval_cpu_target_{"llvm"};

  // Cache the following ops for equivalence checking in this pass.
  const Op& device_copy_op_;
  const Op& shape_of_op_;
  const Op& vm_shape_of_op_;
  const Op& cast_op_;
  const Op& ndarray_size_op_;

  // True if currently within a "primitive" Relay Function.
  bool inside_primitive_ = false;
};

}  // namespace

TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(IsComplexConstant);

/*!
 * \brief Returns \p expr with any constant sub-expressions evaluated and let-bound constants
 * inlined. Returns \p expr unchanged if there is nothing to fold.
 *
 * CAUTION: The importers rely on this function returning \p expr unchanged to preserve sharing
 * from their point of view. Furthermore, this function can be called before conversion to ANF,
 * so we must avoid all recursion.
 */
Expr FoldConstantExpr(const Expr& expr, const IRModule& mod) {
  VLOG_CONTEXT << "FoldConstantExpr";
  VLOG(1) << "folding:" << std::endl << PrettyPrint(expr);
  Expr result = ConstantFolder(mod).VisitExpr(expr);
  VLOG(1) << "folded to:" << std::endl << PrettyPrint(result);
  return result;
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstantExpr);

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(FoldConstantExpr(f, m));
      };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
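
// Usage sketch (hedged; assumes the standard TVM pass infrastructure): the pass registered
// above can be applied to a whole module from C++,
//   IRModule mod = ...;  // some Relay module
//   mod = relay::transform::FoldConstant()(mod);
// or from Python via tvm.relay.transform.FoldConstant(). Per the CreateFunctionPass call
// above, it is registered at opt level 2.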