/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/runtime/contrib/dnnl/dnnl_json_runtime.cc
 * \brief A simple JSON runtime for DNNL.
 */

#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include <cstddef>
#include <string>
#include <vector>

#include "../json/json_node.h"
#include "../json/json_runtime.h"
#include "dnnl.hpp"

namespace tvm {
namespace runtime {
namespace contrib {

using namespace tvm::runtime;
using namespace tvm::runtime::json;

class DNNLJSONRuntime : public JSONRuntimeBase {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;

 public:
  DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
                  const Array<String> const_names)
      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}

  const char* type_key() const { return "dnnl_json"; }

  void Init(const Array<NDArray>& consts) override {
    BuildEngine();

    ICHECK_EQ(consts.size(), const_idx_.size())
        << "The number of input constants must match the number required.";

    // Setup constants entries for weights.
    SetupConstants(consts);
  }

  void Run() override {
    // Fill in the input buffers.
    for (size_t i = 0; i < input_nodes_.size(); ++i) {
      auto eid = EntryID(input_nodes_[i], 0);
      // TODO(@comaniac): Support other data lengths.
      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
      size_t buffer_size = GetDataSize(*data_entry_[eid]);
      write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
                           offset_in_bytes);
    }

    // Invoke the engine through interpreting the stream.
    for (size_t i = 0; i < net_.size(); ++i) {
      net_.at(i).execute(stream_, net_args_.at(i));
    }
    stream_.wait();

    // Read output buffers.
    for (size_t i = 0; i < outputs_.size(); ++i) {
      auto eid = EntryID(outputs_[i]);
      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
      size_t buffer_size = GetDataSize(*data_entry_[eid]);
      read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
                            offset_in_bytes);
    }
  }
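  // Note on the `* 4` above: entry_out_mem_ stores offsets in *elements*, and only 32-bit
  // (4-byte) data is supported for now, per the TODO. A minimal sketch of the intended
  // generalization, reading the element width from the entry's DLDataType instead of
  // hard-coding it (illustrative only; `dtype.bits` is a real DLTensor field):
  //
  //   size_t elem_bytes = (data_entry_[eid]->dtype.bits + 7) / 8;
  //   size_t offset_in_bytes = entry_out_mem_[eid].second * elem_bytes;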
 private:
  // Build up the engine based on the input graph.
  void BuildEngine() {
    engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
    stream_ = dnnl::stream(engine_);

    // Build subgraph engine.
    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
      const auto& node = nodes_[nid];
      if (node.GetOpType() == "kernel") {
        ICHECK_EQ(node.GetOpType(), "kernel");
        auto op_name = node.GetOpName();
        if ("nn.conv2d" == op_name) {
          Conv2d(nid);
        } else if ("dnnl.conv2d_relu" == op_name) {
          Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
        } else if ("dnnl.conv2d_tanh" == op_name) {
          Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
        } else if ("dnnl.conv2d_sigmoid" == op_name) {
          Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
        } else if ("dnnl.conv2d_bias" == op_name) {
          Conv2d(nid, false, true);
        } else if ("dnnl.conv2d_bias_relu" == op_name) {
          Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
        } else if ("dnnl.conv2d_bias_tanh" == op_name) {
          Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
        } else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
          Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
        } else if ("nn.dense" == op_name) {
          Dense(nid);
        } else if ("dnnl.dense_bias" == op_name) {
          Dense(nid, true);
        } else if ("nn.batch_norm" == op_name) {
          BatchNorm(nid);
        } else if ("nn.relu" == op_name) {
          Eltwise(nid, dnnl::algorithm::eltwise_relu);
        } else if ("tanh" == op_name) {
          Eltwise(nid, dnnl::algorithm::eltwise_tanh);
        } else if ("sigmoid" == op_name) {
          Eltwise(nid, dnnl::algorithm::eltwise_logistic);
        } else if ("add" == op_name) {
          Binary(nid, dnnl::algorithm::binary_add);
        } else if ("multiply" == op_name) {
          Binary(nid, dnnl::algorithm::binary_mul);
        } else {
          LOG(FATAL) << "Unsupported op: " << op_name;
        }
      }
    }
  }

  // Bind a JSON graph node entry to a DNNL memory.
  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc,
                              size_t offset = 0) {
    auto eid = EntryID(entry);
    if (entry_out_mem_.count(eid) == 0) {
      return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset);
    }
    return entry_out_mem_[eid].first;
  }

  // Bind a JSON graph node entry to a given DNNL memory.
  dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem,
                              size_t offset = 0) {
    auto eid = EntryID(entry);
    // Since the DNNL memory has been created before calling this function, we assume the entry
    // has not yet been bound to another DNNL memory; otherwise it may leak memory.
    ICHECK_EQ(entry_out_mem_.count(eid), 0);

    // TODO(@comaniac): Support other data types (e.g., int8).
    auto data_node = nodes_[entry.id_];
    auto dltype = data_node.GetOpDataType()[entry.index_];
    ICHECK_EQ(dltype.bits, 32);

    entry_out_mem_[eid] = {mem, offset};
    return entry_out_mem_[eid].first;
  }
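  // Usage sketch for the two overloads above (illustrative only): a producer binds an entry
  // to a freshly allocated memory via the desc overload; any later call for the same entry
  // returns the cached memory instead of allocating a new one:
  //
  //   auto m1 = BindDNNLMemory(entry, dnnl::memory::desc{dims, dt::f32, tag::nchw});
  //   auto m2 = BindDNNLMemory(entry, dnnl::memory::desc{dims, dt::f32, tag::nchw});
  //   // m1 and m2 refer to the same underlying dnnl::memory.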
  void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
              dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
    auto node = nodes_[nid];

    // Setup attributes.
    auto data_entry = node.GetInputs()[0];
    auto weight_entry = node.GetInputs()[1];
    dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
    dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
    std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
    std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
    std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
    dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);

    dnnl::memory::dim N = input_shape[0],           // batch size
        IC = input_shape[1],                        // input channels
        IH = input_shape[2],                        // input height
        IW = input_shape[3],                        // input width
        OC = weight_shape[0],                       // output channels
        KH = weight_shape[2],                       // kernel height
        KW = weight_shape[3],                       // kernel width
        PW_L = std::stoi(str_padding[1]),           // width padding: left
        PW_R = std::stoi(str_padding[3]),           // width padding: right
        PH_L = std::stoi(str_padding[0]),           // height padding: top
        PH_R = std::stoi(str_padding[2]),           // height padding: bottom
        SH = std::stoi(str_strides[0]),             // height-wise stride
        SW = std::stoi(str_strides[1]),             // width-wise stride
        DH = std::stoi(str_dilates[0]) - 1,         // height-wise dilation
        DW = std::stoi(str_dilates[1]) - 1,         // width-wise dilation
        DKH = 1 + (KH - 1) * (DH + 1),              // dilated kernel height
        DKW = 1 + (KW - 1) * (DW + 1),              // dilated kernel width
        OH = (IH - DKH + PH_L + PH_R) / SH + 1,     // output height
        OW = (IW - DKW + PW_L + PW_R) / SW + 1;     // output width

    // Memory shapes.
    dnnl::memory::dims src_dims = {N, IC, IH, IW};
    dnnl::memory::dims weights_dims = {OC, IC, KH, KW};
    if (groups > 1) {
      weights_dims = {groups, 1, IC / groups, KH, KW};
    }
    dnnl::memory::dims bias_dims = {OC};
    dnnl::memory::dims dst_dims = {N, OC, OH, OW};
    dnnl::memory::dims strides_dims = {SH, SW};
    dnnl::memory::dims dilates_dims = {DH, DW};
    dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
    dnnl::memory::dims padding_dims_r = {PH_R, PW_R};

    // Memory descriptions.
    auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, tag::any);
    auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, tag::any);
    auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
    auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::nchw);

    // Conv2d description.
    auto conv_desc = dnnl::convolution_forward::desc(
        dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
        padding_dims_r);

    // Enable elementwise post-ops.
    dnnl::primitive_attr attr;
    if (has_elt) {
      dnnl::post_ops ops;
      ops.append_eltwise(1.f, algo, 0.f, 0.f);
      attr.set_post_ops(ops);
    }

    auto conv2d_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_);

    // Push to the network.
    auto conv = dnnl::convolution_forward(conv2d_prim_desc);
    net_.push_back(conv);

    // Data memory.
    ICHECK_EQ(node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCHW");
    auto conv2d_src_memory = BindDNNLMemory(data_entry, {src_dims, dt::f32, tag::nchw});

    // Weight memory.
    ICHECK_EQ(node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIHW");
    auto conv2d_weights_memory = BindDNNLMemory(
        weight_entry, {weights_dims, dt::f32, (groups > 1) ? tag::goihw : tag::oihw});

    // Bias memory.
    auto conv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
    if (has_bias) {
      auto bias_entry = node.GetInputs()[2];
      BindDNNLMemory(bias_entry, conv2d_bias_memory);
    } else {
      // Fill with zeros; a bias memory is always passed to the primitive.
      std::vector<float> bias(OC, 0.f);
      write_to_dnnl_memory(bias.data(), conv2d_bias_memory, OC * sizeof(float));
    }

    // Output memory.
    JSONGraphNodeEntry out_entry(nid, 0);
    auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc());

    // Bind memory buffers.
    net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory},
                         {DNNL_ARG_WEIGHTS, conv2d_weights_memory},
                         {DNNL_ARG_BIAS, conv2d_bias_memory},
                         {DNNL_ARG_DST, conv2d_dst_memory}});
  }
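  // Worked example of the output-size arithmetic above (values chosen for illustration):
  // with IH = 224, KH = 3, dilation 1 (so DH = 0 and DKH = 3), PH_L = PH_R = 1, SH = 1:
  //
  //   OH = (224 - 3 + 1 + 1) / 1 + 1 = 224
  //
  // i.e. "same" padding preserves the spatial size, matching the standard convolution
  // formula OH = floor((IH + PH_L + PH_R - DKH) / SH) + 1.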
  void Dense(const size_t& nid, const bool has_bias = false) {
    auto node = nodes_[nid];

    // Setup attributes.
    auto data_entry = node.GetInputs()[0];
    auto weight_entry = node.GetInputs()[1];
    dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
    dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];

    dnnl::memory::dim B = input_shape[0],  // batch size
        IC = input_shape[1],               // input channels
        OC = weight_shape[0];              // output channels

    // Memory shapes.
    dnnl::memory::dims data_dims = {B, IC};
    dnnl::memory::dims weight_dims = {OC, IC};
    dnnl::memory::dims bias_dims = {OC};
    dnnl::memory::dims out_dims = {B, OC};

    // Memory descriptions.
    auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc});
    auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc});
    auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x});
    auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc});

    // Dense description.
    auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference,
                                                        data_md, weight_md, bias_md, dst_md);
    auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, engine_);

    auto dense = dnnl::inner_product_forward(dense_prim_desc);
    net_.push_back(dense);

    // Memories.
    auto data_memory = BindDNNLMemory(data_entry, data_md);
    auto weight_memory = BindDNNLMemory(weight_entry, weight_md);

    // Bias memory.
    auto bias_memory = dnnl::memory(bias_md, engine_);
    if (has_bias) {
      auto bias_entry = node.GetInputs()[2];
      BindDNNLMemory(bias_entry, bias_memory);
    } else {
      // Fill with zeros; a bias memory is always passed to the primitive.
      std::vector<float> bias(OC, 0.f);
      write_to_dnnl_memory(bias.data(), bias_memory, OC * sizeof(float));
    }

    // Output memory.
    JSONGraphNodeEntry out_entry(nid, 0);
    auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());

    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
                         {DNNL_ARG_WEIGHTS, weight_memory},
                         {DNNL_ARG_BIAS, bias_memory},
                         {DNNL_ARG_DST, dst_memory}});
  }

  void BatchNorm(const size_t& nid) {
    auto node = nodes_[nid];

    auto data_entry = node.GetInputs()[0];
    auto gamma_entry = node.GetInputs()[1];
    auto beta_entry = node.GetInputs()[2];
    auto mean_entry = node.GetInputs()[3];
    auto variance_entry = node.GetInputs()[4];
    dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
    dnnl::memory::dim IC = data_shape[1];
    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);

    // Memory description.
    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);

    // BN description.
    auto bn_desc = dnnl::batch_normalization_forward::desc(
        dnnl::prop_kind::forward_inference, data_md, epsilon,
        dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift);
    auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_);
    auto bn = dnnl::batch_normalization_forward(bn_prim_desc);
    net_.push_back(bn);

    // Memories.
    auto data_memory = BindDNNLMemory(data_entry, data_md);
    JSONGraphNodeEntry out_entry(nid, 0);
    auto out_memory = BindDNNLMemory(out_entry, data_md);
    auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc());
    auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc());

    // In DNNL, weight is composed of gamma+beta, so we point them to the same DNNL memory but
    // assign an offset to beta data for runtime serialization.
    auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0);
    BindDNNLMemory(beta_entry, weight_memory, IC);

    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
                         {DNNL_ARG_DST, out_memory},
                         {DNNL_ARG_SCALE_SHIFT, weight_memory},
                         {DNNL_ARG_MEAN, mean_memory},
                         {DNNL_ARG_VARIANCE, variance_memory}});
  }
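  // Layout sketch of the packed scale_shift memory built above (illustrative, for IC = 4):
  //
  //   element index:  0   1   2   3   4   5   6   7
  //   contents:       g0  g1  g2  g3  b0  b1  b2  b3
  //
  // Gamma occupies elements [0, IC) and beta elements [IC, 2*IC). The offset IC passed to
  // BindDNNLMemory is in elements; Run() scales it to bytes (x4 for fp32) when it copies
  // each entry's buffer into the shared memory.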
  void Eltwise(const size_t& nid, dnnl::algorithm algo) {
    auto node = nodes_[nid];

    auto data_entry = node.GetInputs()[0];
    dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);

    auto elt_desc =
        dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
    auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
    ICHECK(data_md == elt_prim_desc.dst_desc());

    auto elt = dnnl::eltwise_forward(elt_prim_desc);
    net_.push_back(elt);

    auto data_memory = BindDNNLMemory(data_entry, data_md);
    JSONGraphNodeEntry out_entry(nid, 0);
    auto out_memory = BindDNNLMemory(out_entry, data_md);

    net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}});
  }

  void Binary(const size_t& nid, dnnl::algorithm algo) {
    auto node = nodes_[nid];

    // Memory and compute description.
    std::vector<dnnl::memory::dims> data_dims;
    std::vector<dnnl::memory::desc> data_mds;
    std::vector<dnnl::memory> data_memories;

    ICHECK_EQ(node.GetInputs().size(), 2U);
    for (auto entry : node.GetInputs()) {
      auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
      dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);

      data_dims.push_back(data_shape);
      data_mds.push_back(data_md);
      data_memories.push_back(BindDNNLMemory(entry, data_md));
    }
    ICHECK(data_dims[0] == data_dims[1]);
    auto out_md = data_mds[0];
    JSONGraphNodeEntry out_entry(nid, 0);
    auto out_memory = BindDNNLMemory(out_entry, out_md);

    auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md);
    auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_);
    auto binary = dnnl::binary(binary_prim_desc);
    net_.push_back(binary);

    net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]},
                         {DNNL_ARG_SRC_1, data_memories[1]},
                         {DNNL_ARG_DST, out_memory}});
  }

  // Read from DNNL memory (+offset) and write to the handle.
  inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
                                    size_t offset = 0) {
    uint8_t* src = static_cast<uint8_t*>(mem.get_data_handle());
    std::copy(src + offset, src + offset + size, static_cast<uint8_t*>(handle));
  }

  // Read from the handle and write to DNNL memory (+offset).
  inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
                                   size_t offset = 0) {
    uint8_t* dst = static_cast<uint8_t*>(mem.get_data_handle());
    std::copy(reinterpret_cast<uint8_t*>(handle), reinterpret_cast<uint8_t*>(handle) + size,
              dst + offset);
  }
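  // Usage sketch for the two helpers above (illustrative only): round-trip a small fp32
  // buffer through a DNNL memory. `mem` is assumed to be a CPU fp32 memory holding at
  // least 4 elements; note that both `size` and `offset` are expressed in bytes.
  //
  //   float in[4] = {1.f, 2.f, 3.f, 4.f}, out[4];
  //   write_to_dnnl_memory(in, mem, sizeof(in));
  //   read_from_dnnl_memory(out, mem, sizeof(out));
  //   // out now holds the same values as in.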
  // Generate DNNL memory description and infer the data layout by the given shape.
  inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) {
    dnnl::memory::desc data_md;
    switch (shape.size()) {
      case 2:
        data_md = dnnl::memory::desc({shape, dtype, tag::ab});
        break;
      case 3:
        data_md = dnnl::memory::desc({shape, dtype, tag::abc});
        break;
      case 4:
        data_md = dnnl::memory::desc({shape, dtype, tag::abcd});
        break;
      case 5:
        data_md = dnnl::memory::desc({shape, dtype, tag::abcde});
        break;
      default:
        LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
        break;
    }
    return data_md;
  }

  /* The dnnl engine. */
  dnnl::engine engine_;
  /* The dnnl stream. */
  dnnl::stream stream_;
  /* The network layers that are represented in dnnl primitives. */
  std::vector<dnnl::primitive> net_;
  /* The memory that is consumed by arguments. */
  std::vector<std::unordered_map<int, dnnl::memory>> net_args_;
  /* The entry ID to its corresponding output memory. */
  std::unordered_map<uint32_t, std::pair<dnnl::memory, size_t>> entry_out_mem_;
};

runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
                                      const Array<String>& const_names) {
  auto n = make_object<DNNLJSONRuntime>(symbol_name, graph_json, const_names);
  return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate);

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
    .set_body_typed(JSONRuntimeBase::LoadFromBinary<DNNLJSONRuntime>);

}  // namespace contrib
}  // namespace runtime
}  // namespace tvm
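// Usage sketch (illustrative, not part of the runtime): the module is normally produced by
// the BYOC partitioning flow, but it can also be created directly through the registered
// packed function. The symbol name, graph JSON, and constant names below are placeholders.
//
//   const auto* create = tvm::runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
//   ICHECK(create != nullptr);
//   tvm::runtime::Module mod =
//       (*create)(String("dnnl_0"), String(graph_json), Array<String>{});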