/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file fold_scale_axis.cc
 *
 * \brief Fold axis scaling into weights of
 *   conv/dense operators.
 */
#include <tvm/ir/attrs.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>

#include "../op/tensor/transform.h"
#include "pass_utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {
/*!
 * \brief namespace of fold scale axis
 *
 * Use namespace to reduce potential naming conflict.
 */
namespace fold_scale_axis {

using runtime::TypedPackedFunc;

// FoldScaleAxis algorithm:
//
// The general idea is to transform Expr to a tuple of
// (value, axes, scale), where the final result satisfies:
//
// result = value
// for i, k in enumerate(axes):
//    k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that a certain scale may never be consumed
// if no dense/conv2d follows the multiplication.
//
// In order to make sure all the scales we send out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// The forward folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation.
//
// Similarly, the backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by pushing the axes scale signal down to the inputs.
//
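// As a concrete sketch of the forward direction (illustrative Relay-style
// pseudo-IR, assuming NCHW layout so the channel axis is axis 1):
//
//   %1 = multiply(%x, %scale)   // %scale broadcasts along the channel axis
//   %2 = nn.conv2d(%1, %w)
//
// is rewritten so the scale lands on the constant weight instead:
//
//   %2 = nn.conv2d(%x, multiply(%w, %scale'))  // %scale' reshaped to match %w
//
// after which FoldConstant can remove the multiplication entirely.
//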
/*!
 * \brief sorted array axis, can also be nullptr.
 *
 *  nullptr means no scaling request can be done.
 */
using AxesSet = Array<Integer>;

class Message;

/*!
 * \brief Message propagated during the prepare phase.
 */
class MessageNode : public RelayNode {
 public:
  /*! \brief Axes for scaling */
  AxesSet axes;
  /*!
   * \brief Whether folding requires the scale to be a positive constant. This is necessary if some
   *        operators (e.g. Relu) are present.
   */
  bool require_positive;

  static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
  TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode);
};

class Message : public ObjectRef {
 public:
  /*!
   * \brief The constructor
   * \param axes Axes for scaling
   * \param require_positive If folding requires the scales to be positive
   *        values.
   */
  Message(const AxesSet& axes, bool require_positive);

  TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode);
};

Message::Message(const AxesSet& axes, bool require_positive) {
  auto n = make_object<MessageNode>();
  n->axes = axes;
  n->require_positive = require_positive;
  data_ = std::move(n);
}

/*!
 * \brief Merge two axis sets together by taking
 *  their intersection.
 *
 * \note The axes in an AxesSet should be sorted.
 *
 * \param lhs The left axes.
 * \param rhs The right axes.
 * \return The result of the intersection.
 */
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  // This code relies on axes in an AxesSet to be sorted.
  AxesSet ret;
  size_t i = 0, j = 0;
  while (i < lhs.size() && j < rhs.size()) {
    if (lhs[i]->value < rhs[j]->value) {
      ++i;
    } else if (lhs[i]->value > rhs[j]->value) {
      ++j;
    } else {
      ret.push_back(lhs[i]);
      ++i;
      ++j;
    }
  }
  return ret;
}
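// For example, Intersect({1, 4}, {1, 2}) = {1}. An undefined set is absorbing:
// Intersect(nullptr, {1}) = nullptr, i.e. no scaling request survives.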
/*!
 * \brief Merge two messages together by taking intersection.
 *
 * \param lhs The lhs message.
 * \param rhs The rhs message.
 * \return The result of intersection.
 */
Message Intersect(const Message& lhs, const Message& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  auto axes = Intersect(lhs->axes, rhs->axes);
  return Message(axes, lhs->require_positive || rhs->require_positive);
}

/*!
 * \brief Preparation function for pass scale forward.
 * \param call The call node.
 * \param out_message Message from the output containing possible scaling on axes and whether
 *        positive scale is required.
 * \return The message containing the result scaling on axes of the input.
 */
using FForwardPrep =
    runtime::TypedPackedFunc<Array<Message>(const Call& call, const Message& out_message)>;

/*! \brief Axis scale tuple. */
class ScaledExprNode : public TempExprNode {
 public:
  /*! \brief The value */
  Expr value;
  /*! \brief The axes to scale, can be nullptr (means no-scaling) */
  AxesSet axes = NullValue<AxesSet>();
  /*! \brief The scaling factor */
  Expr scale = NullValue<Expr>();

  Expr Realize() const final {
    ICHECK(!axes.defined()) << "outstanding scale";
    return value;
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("axes", &axes);
    v->Visit("scale", &scale);
  }

  static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode);
};

using FForwardRewrite = TypedPackedFunc<Expr(const Call& ref_call, const Array<Expr>& new_args,
                                             const Message& message)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class ForwardPrep : private MixedModeVisitor {
 public:
  std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
    this->Update(body, NullValue<Message>());
    this->VisitExpr(body);
    // flist is added in the Post-DFS order,
    // which is a special case of topological order.
    // We reversely traverse the list to invoke the lazy functions.
    // This acts like a backprop of valid scale axis messages.
    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
      (*it)();
    }
    // return the created message;
    return std::move(message_);
  }

 private:
  // The invoke list
  std::vector<std::function<void()>> flist_;
  // The message on each node.
  std::unordered_map<const Object*, Message> message_;
  // Update the message stored at the node.
  void Update(const Expr& node, const Message& message) {
    // We run intersection of messages:
    //
    // %y = multiply(%x, %scale)
    // %z1 = conv2d(%y, %w)
    // %z2 = exp(%y)
    //
    // Consider the above code example:
    // because %z2 will propagate null to %y,
    // the AxesSet on %y is also null,
    // and the forward folding won't be triggered.
    const Object* key = node.get();
    if (message_.count(key)) {
      message_[key] = Intersect(message_[key], message);
    } else {
      message_[key] = message;
    }
  }

  // We intentionally override the following implementations from ExprVisitor.
  using MixedModeVisitor::VisitExpr_;

  // Visitor pattern override.
  void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); }

  void VisitExpr_(const LetNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // Do not pass the scale signal through a let binding:
    // assigning NullValue<Message> means the fuse signal
    // cannot pass through into these subexpressions.
    auto flazy = [this, op]() {
      this->Update(op->value, NullValue<Message>());
      this->Update(op->body, NullValue<Message>());
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const FunctionNode* op) final {
    ExprVisitor::VisitExpr_(op);
    auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const CallNode* call) final {
    ExprVisitor::VisitExpr_(call);
    // function to be lazily invoked
    auto flazy = [this, call]() {
      static const auto& fprep = Op::GetAttrMap<FForwardPrep>("FScaleAxisForwardPrep");
      // find the message sent to this node.
      auto it = message_.find(call);
      Message out_message;
      if (it != message_.end()) {
        out_message = it->second;
      } else {
        out_message = NullValue<Message>();
      }
      // pass the message back to all the children it references.
      auto f = fprep.get(call->op, nullptr);
      if (f != nullptr) {
        Array<Message> in_messages = f(GetRef<Call>(call), out_message);
        ICHECK_EQ(in_messages.size(), call->args.size());
        for (size_t i = 0; i < call->args.size(); ++i) {
          this->Update(call->args[i], in_messages[i]);
        }
      } else {
        for (size_t i = 0; i < call->args.size(); ++i) {
          this->Update(call->args[i], NullValue<Message>());
        }
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const TupleNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // do not support passing scale through tuple for now.
    auto flazy = [this, op]() {
      for (const Expr& field : op->fields) {
        this->Update(field, NullValue<Message>());
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const IfNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // Do not pass the scale signal through the condition or branches:
    // assigning NullValue<Message> means the fuse signal
    // cannot pass through into these subexpressions.
    auto flazy = [this, op]() {
      this->Update(op->cond, NullValue<Message>());
      this->Update(op->true_branch, NullValue<Message>());
      this->Update(op->false_branch, NullValue<Message>());
    };
    flist_.push_back(flazy);
  }
};

static bool IsIntInArray(const Array<Integer>& axis, int v) {
  for (size_t i = 0; i < axis.size(); i++) {
    if (axis[i] == v) return true;
  }
  return false;
}

static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
                               const Array<Integer>& axis) {
  Array<Integer> arr;
  for (size_t i = 0; i < shape.size(); i++) {
    if (IsIntInArray(axis, i)) {
      auto node = shape[i].as<IntImmNode>();
      if (!node) {
        // if the shape is not a constant, use normal transform
        return Expr();
      }
      arr.push_back(node->value);
    } else {
      arr.push_back(1);
    }
  }
  return MakeReshape(scale, std::move(arr));
}

// if only one axis, use expand dims; else, use reshape
static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
                                       const Array<Integer>& axis) {
  if (axis.size() > 1) {
    return ReshapeToMatchAxis(scale, shape, axis);
  } else {
    return ExpandBiasToMatchAxis(scale, shape.size(), axis);
  }
}
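// For instance (illustrative shapes): for a blocked weight of shape
// (32, 8, 3, 3, 4) and axes {0, 4}, ReshapeToMatchAxis reshapes the scale to
// (32, 1, 1, 1, 4); for a single axis {0}, ExpandBiasToMatchAxis is used instead.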
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------

// Intermediate operators
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
  if (out_message.defined()) {
    return {Message(out_message->axes, true)};
  }
  return {out_message};
}

Expr ReluForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                        const Message& message) {
  const auto* input = new_args[0].as<ScaledExprNode>();
  if (input == nullptr) return Expr(nullptr);
  // return transformed conv2d
  auto rnode = make_object<ScaledExprNode>();
  rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
  rnode->scale = input->scale;
  rnode->axes = input->axes;
  return Expr(rnode);
}

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
                                                       ReluForwardRewrite);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

// AddSub
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  auto none = NullValue<Message>();
  if (out_message.defined()) {
    if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
      return {out_message, none};
    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
      return {none, out_message};
    }
  }
  return {none, none};
}

Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                          const Message& message) {
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  if (!slhs && !srhs) return Expr();
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  auto rnode = make_object<ScaledExprNode>();

  if (slhs != nullptr) {
    ICHECK(srhs == nullptr);
    ICHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
    Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
    if (!scale.defined()) {
      return Expr();
    }
    Expr rhs = Divide(new_args[1], scale);
    rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
    rnode->scale = slhs->scale;
    rnode->axes = slhs->axes;
  } else {
    ICHECK(srhs != nullptr);
    ICHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
    Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
    if (!scale.defined()) {
      return Expr();
    }
    Expr lhs = Divide(new_args[0], scale);
    rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
    rnode->scale = srhs->scale;
    rnode->axes = srhs->axes;
  }
  return Expr(rnode);
}

RELAY_REGISTER_OP("add").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
                                                   AddSubForwardRewrite);

RELAY_REGISTER_OP("subtract").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("subtract")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
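// In effect, AddSubForwardRewrite uses the identity
//   x * s + b == (x + b / s) * s
// (analogously for subtract), dividing the unscaled operand by the scale so
// the pending scale keeps propagating past the add/sub instead of being
// realized here.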
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                            const Message& message) {
  if (!message.defined()) return Expr();
  const auto& expected_out_axes = message->axes;
  ICHECK(expected_out_axes.defined() && expected_out_axes.size());
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  ICHECK(!slhs && !srhs);

  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  Expr lhs = new_args[0];
  Expr rhs = new_args[1];
  auto rnode = make_object<ScaledExprNode>();

  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
      (!message->require_positive || IsAllPositiveConstant(rhs))) {
    rnode->value = lhs;
    rnode->scale = rhs;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
             (!message->require_positive || IsAllPositiveConstant(lhs))) {
    rnode->value = rhs;
    rnode->scale = lhs;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  } else {
    return Expr();
  }
}

RELAY_REGISTER_OP("multiply")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);

// Consumer operators
// Conv2D sends out the requirement of axis folding.
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
  // TODO(tvm-team) support general data layout
  // by transforming weight
  const auto* param = call->attrs.as<Conv2DAttrs>();
  ICHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));

  ICHECK_GE(c_big_axis, 0);
  Message none = NullValue<Message>();
  // For now, we only support the simple pattern (no folded weight/data).
  // More general layouts can be supported under the current framework
  // by using a unified layout transformation;
  // we only need to change the Prep and Mutate functions.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  if (param->groups == 1 || is_depthwise_conv2d) {
    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
      Array<Integer> arr{c_big_axis};
      if (c_small_axis >= 0) {
        arr.push_back(c_small_axis);
      }
      return {Message(arr, false), none};
    }
  }
  return {none, none};
}
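// For example, with data_layout = "NCHW" only c_big_axis = 1 exists (simple
// layout), so the requested axes are {1}; with "NCHW4c" we also have
// c_small_axis = 4 (blocked layout) and the request becomes {1, 4}.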
// Conv2D consumes the scale axis during transformation.
Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                          const Message& message) {
  // if data does not have scale, normal transform path.
  const auto* sdata = new_args[0].as<ScaledExprNode>();
  const auto* sweight = new_args[1].as<ScaledExprNode>();
  if (sdata == nullptr) return Expr();
  if (sweight != nullptr) return Expr();
  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
  ICHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  ICHECK_GE(c_big_axis, 0);
  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));

  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
  ICHECK(is_simple || is_blocking);

  // Check it must be depthwise or full conv2d.
  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
  ICHECK(param->groups == 1 || is_depthwise_conv2d);

  Expr weight = new_args[1];

  // match the ic_axis
  if (is_depthwise_conv2d) {
    if (is_simple) {
      Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
      weight = Multiply(weight, scale);
    } else {
      weight = Multiply(weight,
                        ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
                                           {big_ko_axis, small_ko_axis}));
      if (!weight.defined()) return Expr();
    }
  } else {
    if (is_simple) {
      Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
      weight = Multiply(weight, scale);
    } else {
      weight = Multiply(weight,
                        ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
                                           {big_ki_axis, small_ki_axis}));
      if (!weight.defined()) return Expr();
    }
  }
  // return transformed conv2d
  return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);

Expr ForwardFoldScaleAxis(const Expr& data) {
  auto message = ForwardPrep().Prepare(data);
  for (const auto& m : message) {
    if (m.second.defined()) {
      // run optimization
      auto fcontext = [&](const Call& call) -> ObjectRef {
        auto it = message.find(call.get());
        if (it != message.end()) {
          return it->second;
        } else {
          return ObjectRef(nullptr);
        }
      };
      return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
    }
  }
  // no messages - no optimization
  return data;
}

//----------------------------------------
// Implement backward transformations.
//----------------------------------------
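// A sketch of the backward direction (illustrative Relay-style pseudo-IR):
//
//   %1 = nn.conv2d(%x, %w)
//   %2 = nn.relu(%1)
//   %3 = multiply(%2, %scale)   // %scale a positive constant on the channel axis
//
// is rewritten by pushing the scale back through relu into the weight:
//
//   %1 = nn.conv2d(%x, multiply(%w, %scale'))
//   %3 = nn.relu(%1)
//
// The positivity requirement comes from relu: relu(x * s) == relu(x) * s
// only holds when s > 0.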
class BackwardTransformer;

/*!
 * \brief Preparation function for pass scale backward.
 * \param call The call node.
 * \param in_messages Messages from the input containing allowed input scaling and whether
 *        positive scale is required.
 * \return Message containing the result scaling on axes of the input.
 */
using FBackwardPrep =
    TypedPackedFunc<Message(const Call& call, const Array<Message>& in_messages)>;

using FBackwardTransform =
    TypedPackedFunc<Expr(const Call& call, const Message& message, const Expr& scale,
                         const BackwardTransformer& transformer)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private MixedModeVisitor {
 public:
  // The message on each node.
  std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
    ref_counter_ = GetExprRefCount(body);
    this->VisitExpr(body);
    return std::move(message_);
  }

 private:
  // The message on each node.
  std::unordered_map<const Object*, Message> message_;
  // reference counter of an internal expr
  std::unordered_map<const Object*, size_t> ref_counter_;
  // Visit the expression.
  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    static const auto& fprep = Op::GetAttrMap<FBackwardPrep>("FScaleAxisBackwardPrep");
    auto f = fprep.get(call->op, nullptr);
    if (f == nullptr) return;
    auto rit = ref_counter_.find(call);
    ICHECK(rit != ref_counter_.end());
    // We only allow propagation of scale backward
    // if the expression is only referred by a single parent.
    if (rit->second != 1) return;
    Array<Message> in_messages = GetInMessages(call);
    Message out_message = f(GetRef<Call>(call), in_messages);
    if (out_message.defined()) {
      message_[call] = out_message;
    }
  }

  Array<Message> GetInMessages(const CallNode* call) {
    Array<Message> in_messages;
    for (Expr arg : call->args) {
      auto it = message_.find(arg.get());
      if (it != message_.end()) {
        in_messages.push_back(it->second);
      } else {
        in_messages.push_back(NullValue<Message>());
      }
    }
    return in_messages;
  }
};
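// Note the ref_counter_ guard above: if an expression feeds multiple consumers,
// e.g. %y = nn.conv2d(%x, %w) used by both multiply(%y, %s) and exp(%y), then
// no message is recorded for it and no scale is pushed back through it.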
/*
 * A hybrid approach is used: the transformation
 * itself is recursive, but the traversal is non-recursive.
 */
class BackwardTransformerNode : public Object, private MixedModeMutator {
 public:
  using MixedModeMutator::Mutate;
  // Run backward transform.
  Expr Fold(Expr expr) {
    message_ = BackwardPrep().Prepare(expr);
    for (const auto& m : message_) {
      if (m.second.defined()) {
        // run optimization
        return this->Mutate(expr);
      }
    }
    // no messages - no optimization
    return expr;
  }
  /*!
   * \brief Transform the expr to consider the scaling.
   */
  Expr Transform(const Expr& expr, Message message, Expr scale);
  /*!
   * \brief Get the message propagated to the expr.
   * \param expr The expression.
   * \return The message containing the expected axes and whether positive scale is required.
   */
  Message GetMessage(const Expr& expr) const {
    auto it = message_.find(expr.get());
    if (it != message_.end()) return it->second;
    return NullValue<Message>();
  }

  // solver is not serializable.
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
  TVM_DECLARE_FINAL_OBJECT_INFO(BackwardTransformerNode, Object);

 private:
  // Valid axes on each node.
  std::unordered_map<const Object*, Message> message_;
  // Override mutation of call.
  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
    return Transform(GetRef<Call>(call_node), NullValue<Message>(), NullValue<Expr>());
  }

 public:
  Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); }
};

class BackwardTransformer : public ObjectRef {
 public:
  BackwardTransformer() {}
  explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
  BackwardTransformerNode* operator->() const {
    return static_cast<BackwardTransformerNode*>(get_mutable());
  }
  using ContainerType = BackwardTransformerNode;
};

/*!
 * \brief Transform the expr to consider the scaling.
 *
 * \param expr The input expression.
 * \param message The axes to scale.
 * \param scale The scale applied to the axes.
 * \return The result of transformation.
 */
Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) {
  if (const CallNode* call_node = expr.as<CallNode>()) {
    static const auto& ftransform =
        Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
    auto f = ftransform.get(call_node->op, nullptr);
    const Call call = GetRef<Call>(call_node);
    // Only consult the memo when there is no outstanding scale message.
    if (!message.defined()) {
      const auto it = memo_.find(call);
      if (it != memo_.end()) {
        return it->second;
      }
    }
    Expr new_expr = NullValue<Expr>();
    if (f != nullptr) {
      new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
    } else {
      ICHECK(!message.defined()) << "outstanding scale";
      new_expr = NormalCallTransform(call.operator->());
    }
    memo_[call] = new_expr;
    return new_expr;
  } else {
    ICHECK(!message.defined()) << "outstanding scale";
    return this->Mutate(expr);
  }
}

//----------------------------------------------
// Per operator defs for FScaleAxisBackward
//----------------------------------------------

// Intermediate operators
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  if (in_messages[0].defined()) {
    return Message(in_messages[0]->axes, true);
  }
  return in_messages[0];
}

Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                           const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Expr input = transformer->Transform(call->args[0], message, scale);
  return Call(call->op, {input}, call->attrs, call->type_args);
}

RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
                                                          ReluBackwardTransform);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
// AddSub
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  StructuralEqual equal;
  if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
    return in_messages[0];
  } else if (in_messages[1].defined() &&
             MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
    return in_messages[1];
  } else if (in_messages[0].defined() && in_messages[1].defined() &&
             equal(in_messages[0]->axes, in_messages[1]->axes) &&
             equal(tlhs->shape, trhs->shape)) {
    // add of two elements.
    return in_messages[0];
  } else {
    auto res = NullValue<Message>();
    return res;
  }
}

Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                             const BackwardTransformer& transformer) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  StructuralEqual equal;

  if (lhs_message.defined() && rhs_message.defined()) {
    ICHECK(equal(lhs_message->axes, rhs_message->axes));
    ICHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (lhs_message.defined()) {
    ICHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
    Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes);
    if (!rhs_scale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
    rhs = Multiply(rhs, rhs_scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (rhs_message.defined()) {
    ICHECK(equal(message->axes, rhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes);
    if (!lhs_scale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
    lhs = Multiply(lhs, lhs_scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else {
    LOG(FATAL) << "outstanding scale";
    return Expr();
  }
}

RELAY_REGISTER_OP("add").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

RELAY_REGISTER_OP("add").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
                                                      AddSubBackwardTransform);

RELAY_REGISTER_OP("subtract").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep",
                                                      AddSubBackwardPrep);

RELAY_REGISTER_OP("subtract")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
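// The transform above relies on distributivity: (x + b) * s == x * s + b * s,
// so the operand that cannot absorb the scale is multiplied by it explicitly
// while the other operand keeps propagating the scale backward.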
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                               const BackwardTransformer& transformer) {
  ICHECK(!message.defined()) << "outstanding scale";
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  if (lhs_message.defined()) {
    ICHECK(lhs_message->axes.defined() && lhs_message->axes.size());
    // NOTE: we do not recursively call Mutate on the scale part,
    // since there is no chance of another scale within the scale part.
    Expr rhs = call->args[1];
    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
        (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
      return transformer->Transform(call->args[0], lhs_message, rhs);
    }
  } else if (rhs_message.defined()) {
    ICHECK(rhs_message->axes.defined() && rhs_message->axes.size());
    Expr lhs = call->args[0];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
        (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
      return transformer->Transform(call->args[1], rhs_message, lhs);
    }
  }
  return transformer->NormalCallTransform(call.operator->());
}

RELAY_REGISTER_OP("multiply")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv2D sends out the requirement of axis folding.
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  const auto* param = call->attrs.as<Conv2DAttrs>();
  ICHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));

  ICHECK_GE(c_big_axis, 0);
  // For now, we only support the simple pattern (no folded weight/data).
  // More general layouts can be supported under the current framework
  // by using a unified layout transformation;
  // we only need to change the Prep and Mutate functions.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  if (param->groups == 1 || is_depthwise_conv2d) {
    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
      Array<Integer> arr{c_big_axis};
      if (c_small_axis >= 0) {
        arr.push_back(c_small_axis);
      }
      return Message(arr, false);
    }
  }
  return NullValue<Message>();
}
// Conv2D consumes the scale axis during transformation.
Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                             const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  const auto* param = call->attrs.as<Conv2DAttrs>();
  ICHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  ICHECK_GE(c_big_axis, 0);
  // For now, we only support the simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  // Check it must be depthwise or full conv2d.
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  ICHECK(param->groups == 1 || is_depthwise_conv2d);
  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
  ICHECK(is_simple || is_blocking);

  Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
  // scale on input for depthwise.
  Expr wscale;
  if (is_simple) {
    wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
  } else {
    wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
                                {big_ko_axis, small_ko_axis});
    if (!wscale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
  }
  weight = Multiply(weight, wscale);
  return Call(call->op, {data, weight}, call->attrs, call->type_args);
}

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);

Expr BackwardFoldScaleAxis(const Expr& data) {
  return make_object<BackwardTransformerNode>()->Fold(data);
}

}  // namespace fold_scale_axis

namespace transform {

Pass ForwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(relay::fold_scale_axis::ForwardFoldScaleAxis(f));
      };
  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis);

Pass BackwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(relay::fold_scale_axis::BackwardFoldScaleAxis(f));
      };
  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis);

Pass FoldScaleAxis() {
  // FoldScaleAxis pass contains the following three passes. Therefore, we can
  // register it as a sequential pass.
  Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
                         "FoldScaleAxis");
  return pass;
}

TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis);

}  // namespace transform

}  // namespace relay
}  // namespace tvm