/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file simplify_inference.cc
 * \brief Rewrite inference-only constructs (batch_norm, dropout, and the
 *        layer/group/instance/l2 norms) into simpler Relay expressions.
 */
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

/*!
 * \brief Unpack nn.batch_norm into its inference-time affine form.
 *
 * With frozen statistics, batch_norm reduces to a single multiply-add:
 *   out = data * scale + shift
 * where
 *   scale = gamma / sqrt(moving_var + epsilon)  (gamma only if param->scale)
 *   shift = beta - moving_mean * scale          (beta only if param->center)
 * and scale/shift are expanded so they broadcast along `axis`.
 */
Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<BatchNormAttrs>();
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);

  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }

  auto ndim = ttype->shape.size();
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});

  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}
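
/*!
 * \brief Unpack nn.group_norm by temporarily reshaping the channel axis.
 *
 * The channel dimension C is split into (num_groups, C / num_groups), e.g.
 * (N, C, H, W) -> (N, G, C/G, H, W); mean and variance are then reduced over
 * (C/G, H, W), the data is normalized as (data - mean) / sqrt(var + epsilon)
 * and reshaped back to the original shape, and gamma/beta are applied along
 * the channel axis when param->scale / param->center are set.
 */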
Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<GroupNormAttrs>();
  ICHECK(param);

  int ndim = ttype->shape.size();
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  Array<Integer> reduced_axes;
  Array<Integer> new_shape;
  Array<Integer> old_shape;
  int num_groups = param->num_groups;
  int channel = ttype->shape[axis].as<IntImmNode>()->value;

  // old_shape = N, C, H, W
  // new_shape = N, num_groups, C/num_groups, H, W
  // reduced_axes = axes of (C/num_groups, H, W)
  for (int i = 0; i < ndim; ++i) {
    auto val = ttype->shape[i].as<IntImmNode>()->value;

    // Save the old shape to reshape back later.
    old_shape.push_back(val);
    if (i == axis) {
      new_shape.push_back(num_groups);
      new_shape.push_back(channel / num_groups);
      reduced_axes.push_back(i + 1);
      continue;
    }
    if (i >= axis) {
      reduced_axes.push_back(i + 1);
    }
    new_shape.push_back(val);
  }

  data = Reshape(data, new_shape);

  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, {reduced_axes}, true, false);
  Expr var = Variance(data, mean, {reduced_axes}, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  out = Reshape(out, old_shape);
  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }

  return out;
}

/*!
 * \brief Unpack nn.layer_norm:
 *   out = (data - mean) / sqrt(var + epsilon) * gamma + beta
 * where mean and var are computed over `axis` with keepdims=true.
 */
Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<LayerNormAttrs>();
  ICHECK(param);

  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, {param->axis}, true, false);
  Expr var = Variance(data, mean, {param->axis}, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  size_t ndim = ttype->shape.size();
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }
  return out;
}
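
/*!
 * \brief Unpack nn.instance_norm.
 *
 * Mean and variance are reduced over every axis except the batch axis (0)
 * and the channel `axis`, then the same normalize-scale-shift pattern as
 * layer_norm is applied:
 *   out = (data - mean) / sqrt(var + epsilon) * gamma + beta
 */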
Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<InstanceNormAttrs>();
  ICHECK(param);

  int ndim = ttype->shape.size();
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  Array<Integer> reduced_axes;
  for (int i = 1; i < ndim; ++i) {
    if (i != axis) reduced_axes.push_back(i);
  }

  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, reduced_axes, true, false);
  Expr var = Variance(data, mean, reduced_axes, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }
  return out;
}

/*!
 * \brief Unpack nn.l2_normalize:
 *   out = data / sqrt(max(sum(data * data, axis), eps))
 */
Expr L2NormToInferUnpack(const Attrs attrs, Expr data) {
  const auto param = attrs.as<L2NormalizeAttrs>();
  ICHECK(param);
  Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->eps));

  Expr sqr = Multiply(data, data);
  Expr sum = Maximum(Sum(sqr, param->axis, true, false), epsilon);
  Expr sqrt = Sqrt(sum);
  return Divide(data, sqrt);
}

/*!
 * \brief Mutator that lowers the normalization ops above and removes dropout.
 *
 * batch_norm and dropout return tuples, so their results are consumed through
 * TupleGetItem: the CallNode rewrite records the batch_norm input type in
 * ty_map_, and the TupleGetItemNode rewrite performs the actual unpacking.
 * The remaining norms return a plain tensor and are rewritten at the call.
 */
class InferenceSimplifier : public MixedModeMutator {
 public:
  InferenceSimplifier()
      : batch_norm_op_(Op::Get("nn.batch_norm")),
        dropout_op_(Op::Get("nn.dropout")),
        instance_norm_op_(Op::Get("nn.instance_norm")),
        layer_norm_op_(Op::Get("nn.layer_norm")),
        group_norm_op_(Op::Get("nn.group_norm")),
        l2_norm_op_(Op::Get("nn.l2_normalize")) {}

  Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final {
    const auto* new_n = new_e.as<TupleGetItemNode>();
    if (new_n->index != 0) {
      return new_e;
    }
    if (const auto* call = new_n->tuple.as<CallNode>()) {
      if (call->op == batch_norm_op_) {
        return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                      call->args[3], call->args[4], ty_map_.at(call->args[0]));
      } else if (call->op == dropout_op_) {
        return call->args[0];
      }
    }
    return new_e;
  }

  Expr Rewrite_(const CallNode* n, const Expr& new_n) {
    if (n->op == batch_norm_op_) {
      ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
    } else if (n->op == layer_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == group_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == instance_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                       n->args[0]->checked_type());
    } else if (n->op == l2_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return L2NormToInferUnpack(call->attrs, call->args[0]);
    }
    return new_n;
  }

 private:
  // Cache the following ops. They will be used in the passes repeatedly for
  // operator equivalence checking so that the registry lookup overhead can be
  // reduced.
  const Op& batch_norm_op_;
  const Op& dropout_op_;
  const Op& instance_norm_op_;
  const Op& layer_norm_op_;
  const Op& group_norm_op_;
  const Op& l2_norm_op_;
  std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> ty_map_;
};

Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); }

namespace transform {

Pass SimplifyInference() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(SimplifyInference(f));
      };
  return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference);

}  // namespace transform

}  // namespace relay

}  // namespace tvm
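
// Usage sketch (illustrative, not part of the pass): given a type-checked
// IRModule `mod`, the pass can be applied as
//   mod = relay::transform::SimplifyInference()(mod);
// CreateFunctionPass above lists {"InferType"} as a required pass, so the
// argument types queried via checked_type() are populated before rewriting.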