/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * * \file convert_sparse_conv2d.cc * * \brief Mutate conv2d operator to sparse conv2d operator */ #include #include #include #include #include #include #include #include #include namespace tvm { namespace relay { // Search conv2d op weight name from Expr class Conv2dOpWeightVisitor : private ExprVisitor { public: Conv2dOpWeightVisitor() : conv2d_op_(Op::Get("nn.conv2d")) {} Array Search(const Expr& expr) { VisitExpr(expr); return memo_; } private: void VisitExpr_(const CallNode* n) final { if (n->op == conv2d_op_) { const auto weight = n->args[1].as(); if (weight) { memo_.push_back(weight->name_hint()); } } for (const auto& arg : n->args) { VisitExpr(arg); } } // Cache op const Op& conv2d_op_; Array memo_; }; // SearchConv2dOpWeight Array SearchConv2dOpWeight(const Expr& e) { return Conv2dOpWeightVisitor().Search(e); } TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(SearchConv2dOpWeight); // Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` class Conv2dToSparseConv2dMutator : public ExprRewriter { public: Conv2dToSparseConv2dMutator(const Array& weight_name, const Array>& weight_shape, const String& layout, int kernel_size) : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) { ICHECK_EQ(weight_name.size(), weight_shape.size()); layout_ = layout; kernel_size_ = kernel_size; for (size_t i = 0; i < weight_name.size(); ++i) { ICHECK(weight_name[i]->IsInstance()); std::string k = weight_name[i].as()->data; const auto& ws = weight_shape[i]; std::vector v(ws.size()); for (size_t j = 0; j < ws.size(); ++j) { v[j] = ws[j].as()->value; } target_weights_.emplace(k, v); } } Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (pre->op == conv2d_op_) { const auto weight = pre->args[1].as(); if (weight) { if (target_weights_.count(weight->name_hint())) { const auto& prefix = weight->name_hint(); const auto& ws = target_weights_.at(prefix); const auto data = post.as()->args[0]; relay::TensorType ws_data_type, ws_indices_type, ws_indptr_type; if (ws.size() == 5) { ws_data_type = relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32)); ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32)); } else if (ws.size() == 4) { ws_data_type = relay::TensorType({ws.at(0), ws.at(1)}, DataType::Float(32)); ws_indices_type = relay::TensorType({ws.at(2)}, DataType::Int(32)); ws_indptr_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); } Var weight_data(prefix + ".data", ws_data_type); Var weight_indices(prefix + ".indices", ws_indices_type); Var weight_indptr(prefix + ".indptr", ws_indptr_type); auto attrs = make_object(); attrs->layout = std::move(layout_); attrs->kernel_size = Array{kernel_size_, kernel_size_}; return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs)); } } } return post; } private: // Cached op const Op& conv2d_op_; const Op& sparse_conv2d_op_; std::unordered_map> target_weights_; String layout_; int kernel_size_; }; // class Conv2dToSparseConv2dAlter Expr Conv2dToSparse(const Expr& e, const Array& weight_name, const Array>& weight_shape, const String& layout, int kernel_size) { auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); return PostOrderRewrite(e, &rewriter); } template auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence) { return std::make_tuple(arr[Is]...); } template auto unpack_to_tuple(elemTy* arr) { return unpack_to_tuple_internal(arr, std::make_index_sequence{}); } struct Range { size_t dim; explicit Range(size_t d) : dim(d) {} struct iterpoint { size_t val, lim; iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {} size_t operator*() const { return val; } iterpoint operator/(const iterpoint& rhs) const { return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim); } }; struct iterator { size_t val, lim; iterator(size_t v1, size_t v2) : val(v1), lim(v2) {} bool operator!=(const iterator& rhs) const { return val != rhs.val; } void operator++() { ++val; } iterpoint operator*() const { return iterpoint(val, lim); } }; iterator begin() { return iterator(0, dim); } iterator end() { return iterator(dim, dim); } }; // Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` class Conv2dToSparseConv2dMutator2 : public ExprRewriter { public: Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW, double sparse_thresh) : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")), dev_cpu0_{DLDeviceType::kDLCPU, 0}, layout_(layout), kernel_size_(kernel_size), blockH_(blockH), blockW_(blockW), sparse_thresh_(sparse_thresh) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { // check op type & attrs const auto pre_attrs = pre->attrs.as(); if (!pre_attrs || pre_attrs->data_layout != layout_ || pre_attrs->strides[0].as()->value != 1 || pre_attrs->kernel_size[0].as()->value != kernel_size_) return post; // check constant weight const auto pre_weight_node = pre->args[1].as(); if (!pre_weight_node) return post; // check weight dtype & shape auto&& pre_weight = pre_weight_node->data; auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32); ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data()); int O, I, H, W; if (layout_ == "NCHW") { std::tie(O, I, H, W) = pre_weight_shape; } else { // NHWC std::tie(H, W, I, O) = pre_weight_shape; } int CO = O, CI = H * W * I; // copy to vector std::vector pre_weight_data(CO * CI); pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float)); if (layout_ == "NHWC") { std::vector tmp(pre_weight_data.size()); for (auto i : Range(CO)) for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)]; std::swap(tmp, pre_weight_data); } // convert to BSR std::vector wdata, block(blockH_ * blockW_); std::vector windices, windptr; for (auto bh : Range(CO / blockH_)) { windptr.push_back(windices.size()); for (auto bw : Range(CI / blockW_)) { int cntnnz = 0; for (auto i : Range(blockH_)) for (auto j : Range(blockW_)) { auto tmp = pre_weight_data[*(bh / i / bw / j)]; if (tmp) cntnnz++; block[*(i / j)] = tmp; } if (cntnnz) { wdata.insert(wdata.end(), block.begin(), block.end()); windices.push_back(*bw); } } } windptr.push_back(windices.size()); double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size(); if (sprate < sparse_thresh_) return post; // constrct return data int nnz = windices.size(); auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_); auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_); auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_); weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float)); weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t)); weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t)); // construct return call auto args = runtime::Array{post.as()->args[0], Constant(weight_data), Constant(weight_indices), Constant(weight_indptr)}; auto attrs = make_object(); attrs->layout = layout_; attrs->kernel_size = Array{kernel_size_, kernel_size_}; return Call(sparse_conv2d_op_, args, Attrs(attrs)); } private: const Op& sparse_conv2d_op_; DLDevice dev_cpu0_; String layout_; int kernel_size_, blockH_, blockW_; double sparse_thresh_; }; // class Conv2dToSparseConv2dMutator2 Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, double sparse_thresh) { auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); return PostOrderRewrite(e, &rewriter); } namespace transform { // Convert a model with seperate weight info (already sparsified). Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, const String& layout, int kernel_size) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); }; return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"}); } TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse").set_body_typed(Conv2dToSparse); // Convert a model with freezed params (sparsified in the pass). Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW, double sparse_thresh) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { auto f0 = Downcast( Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); return f0; }; return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); } TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2); } // namespace transform } // namespace relay } // namespace tvm