/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/op/nn/nn.h
 * \brief Properties def of nn operators for sharing.
 */
#ifndef TVM_RELAY_OP_NN_NN_H_
#define TVM_RELAY_OP_NN_NN_H_

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/relay/type.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../op_common.h"

namespace tvm {
namespace relay {

/*!
 * \brief Type relation for matmul-like operators (e.g. nn.matmul, nn.dense):
 *        infers the weight (tensor_b) and output tensor types from the data type and attributes.
 */
template <typename AttrType>
bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* tensor_a = types[0].as<TensorTypeNode>();
  const auto* tensor_b = types[1].as<TensorTypeNode>();
  if (tensor_a == nullptr) return false;
  ICHECK(static_cast<int>(tensor_a->shape.size()) != 0);

  const AttrType* param = attrs.as<AttrType>();
  ICHECK(param != nullptr);
  // Default set to dense layout
  bool transpose_a = false;
  bool transpose_b = true;
  const auto& mattrs = attrs.as<MatmulAttrs>();
  if (mattrs != nullptr) {
    transpose_a = mattrs->transpose_a;
    transpose_b = mattrs->transpose_b;
  }

  const Array<tvm::PrimExpr>& dshape = tensor_a->shape;
  Array<tvm::PrimExpr> oshape = dshape;
  tvm::PrimExpr reduce = dshape[dshape.size() - 1];
  if (transpose_a) {
    reduce = dshape[dshape.size() - 2];
    oshape.Set((oshape.size() - 2), dshape[oshape.size() - 1]);
  }
  if (param->units.defined()) {
    // validate the tensor_b shape is proper if defined
    // Assign tensor_b type
    const Array<tvm::PrimExpr>& wshape = transpose_b
                                             ? Array<tvm::PrimExpr>({param->units, reduce})
                                             : Array<tvm::PrimExpr>({reduce, param->units});
    // It is possible for tensor_b to be nullptr in which case we will use
    // data dtype as the tensor_b dtype. However if tensor_b dtype is explicitly
    // present we will use that.
    auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
    if (param->auto_scheduler_rewritten_layout.size() == 0) {
      // Normal case: assign result to reporter
      reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
    } else {
      // If the layout is rewritten by auto-scheduler, we just forcibly apply the
      // layout provided by auto-scheduler and skip the normal inference logic.
      {}  // do nothing
    }
    oshape.Set((oshape.size() - 1), param->units);
  } else {
    if (tensor_b == nullptr) return false;
    const Array<tvm::PrimExpr>& wshape = tensor_b->shape;
    // When tensor_b's layout has been rewritten, figure it out based on the
    // total number of elements and input dimensions.
    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      PrimExpr tensor_b_elements = 1;
      for (size_t i = 0; i < wshape.size(); i++) {
        tensor_b_elements = tensor_b_elements * wshape[i];
      }
      oshape.Set(oshape.size() - 1, tensor_b_elements / dshape[dshape.size() - 1]);
      // Otherwise just pull it out of the tensor_b shape directly.
    } else {
      ICHECK(static_cast<int>(tensor_b->shape.size()) == 2);
      if (!tensor_a->shape.back().as<tir::AnyNode>()) {
        ICHECK((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) ||
               (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0])))
            << "MatmulRel: input dimension doesn't match,"
            << " tensor_a shape=" << tensor_a->shape << ", tensor_b shape=" << tensor_b->shape;
      }
      oshape.Set((oshape.size() - 1), transpose_b ? wshape[0] : wshape[1]);
    }
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = tensor_a->dtype;
  }
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
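// Illustrative sketch (not part of this header): a type relation such as MatmulRel is
// typically attached to an operator in the corresponding .cc file via RELAY_REGISTER_OP.
// The registration below is an assumed example for "nn.matmul" with MatmulAttrs; see
// src/relay/op/nn/nn.cc for the authoritative registrations. BatchMatmulRel (defined
// next) is wired up analogously for "nn.batch_matmul".
//
//   RELAY_REGISTER_OP("nn.matmul")
//       .set_num_inputs(2)
//       .add_argument("tensor_a", "nD Tensor", "The first input tensor.")
//       .add_argument("tensor_b", "2D Tensor", "The second input tensor.")
//       .set_attrs_type<MatmulAttrs>()
//       .add_type_rel("Matmul", MatmulRel<MatmulAttrs>);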
/*!
 * \brief Type relation for nn.batch_matmul: checks that the batch and reduction
 *        dimensions are consistent and infers the output tensor type.
 */
template <typename AttrType>
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;

  const AttrType* param = attrs.as<AttrType>();
  ICHECK(param != nullptr);
  bool transpose_a = param->transpose_a;
  bool transpose_b = param->transpose_b;
  const Array<PrimExpr>& y_shape =
      param->auto_scheduler_rewritten_layout.size() == 0
          ? y->shape
          : auto_scheduler::GetShapeFromRewrittenLayout(
                param->auto_scheduler_rewritten_layout,
                transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
                            : tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
  const PrimExpr& xb = x->shape[0];
  const PrimExpr& xi = x->shape[transpose_a ? 2 : 1];
  const PrimExpr& xk = x->shape[transpose_a ? 1 : 2];
  const PrimExpr& yb = y_shape[0];
  const PrimExpr& yk = y_shape[transpose_b ? 2 : 1];
  const PrimExpr& yj = y_shape[transpose_b ? 1 : 2];

  bool is_dyn = false;
  for (size_t i = 0; i < 3; ++i) {
    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
      is_dyn = true;
      break;
    }
  }
  if (!is_dyn) {
    ICHECK(reporter->AssertEQ(xb, yb) || reporter->AssertEQ(xb, 1) || reporter->AssertEQ(yb, 1))
        << "BatchDot: batch dimensions don't match, "
        << " x shape=" << x->shape << ", y shape=" << y_shape;
    ICHECK(reporter->AssertEQ(xk, yk)) << "BatchDot: shapes of x and y are inconsistent, "
                                       << " x shape=" << x->shape << ", y shape=" << y_shape;
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = x->dtype;
  }
  // assign output type
  const auto& out_b =
      xb->IsInstance<tir::AnyNode>() || yb->IsInstance<tir::AnyNode>() ? tir::Any() : max(xb, yb);
  reporter->Assign(types[2], TensorType(Array<tvm::PrimExpr>({out_b, xi, yj}), out_dtype));
  return true;
}

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_OP_NN_NN_H_