/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file split_args.cc */ #include #include #include "../op/annotation/annotation.h" #include "./pattern_utils.h" namespace tvm { namespace relay { class ArgumentSplitter : public ExprRewriter { public: explicit ArgumentSplitter(int max_function_args) : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {} Expr Rewrite_(const CallNode* call, const Expr& post) final { if (max_function_args_ < 0) return post; if (call->op == concat_op_) { auto tuple_node = call->args[0].as(); const auto param = call->attrs.as(); int outputsNum = 1; if (const auto* tuple_type = call->checked_type().as()) { outputsNum = tuple_type->fields.size(); } const int limit = max_function_args_ - outputsNum; int argsNum = tuple_node->fields.size(); if (argsNum < limit) return post; int splitNum = argsNum / limit; splitNum = (argsNum % limit) ? splitNum + 1 : splitNum; std::vector splitted(splitNum); for (int i = 0; i < splitNum; ++i) { int startIdx = i * limit; int argsCount = std::min(limit, argsNum - startIdx); tvm::Array args; args.reserve(argsCount); for (int j = 0; j < argsCount; ++j) { args.push_back(tuple_node->fields[j + startIdx]); } Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(args)); Expr body = MakeConcatenate(new_tuple, param->axis); splitted[i] = StopFusion(body); } tvm::Array tuple_args(splitted); Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(tuple_args)); return MakeConcatenate(new_tuple, param->axis); } return post; } private: const int max_function_args_; const Op& concat_op_; }; Expr SplitArgs(const Expr& expr, int max_function_args) { auto rewriter = ArgumentSplitter(max_function_args); return PostOrderRewrite(expr, &rewriter); } namespace transform { Pass SplitArgs(int max_function_args) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(SplitArgs(f, max_function_args)); }; return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs); } // namespace transform } // namespace relay } // namespace tvm