/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file topk.cc * \brief TopK operators */ #include <tvm/relay/attrs/algorithm.h> #include <tvm/relay/op.h> #include <tvm/tir/op.h> namespace tvm { namespace relay { namespace dyn { bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, k, result] const TopKAttrs* param = attrs.as<TopKAttrs>(); ICHECK_EQ(types.size(), 3); const auto* data = types[0].as<TensorTypeNode>(); const auto* k = types[1].as<TensorTypeNode>(); if (data == nullptr) { ICHECK(types[0].as<IncompleteTypeNode>()) << "tile: expect input type to be TensorType but get " << types[0]; return false; } if (k == nullptr) { ICHECK(types[1].as<IncompleteTypeNode>()) << "tile: expect input type to be TensorType but get " << types[1]; return false; } ICHECK(k->shape.size() <= 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )"; if (k->shape.size() == 1) { const IntImmNode* k_shape = k->shape[0].as<IntImmNode>(); ICHECK(k_shape) << "Parameter k must have static shape"; ICHECK_EQ(k_shape->value, 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )"; } int ndim = data->shape.size(); int axis = param->axis; if (axis < 0) { axis += ndim; } ICHECK(axis >= 0 && axis < ndim); Array<IndexExpr> out_shape; for (int i = 0; i < ndim; ++i) { if (i != axis) { out_shape.push_back(data->shape[i]); } else { out_shape.push_back(Any()); } } auto values_ty = TensorType(out_shape, data->dtype); auto indices_ty = TensorType(out_shape, param->dtype); if (param->ret_type == "both") { reporter->Assign(types[2], TupleType({values_ty, indices_ty})); } else if (param->ret_type == "values") { reporter->Assign(types[2], values_ty); } else if (param->ret_type == "indices") { reporter->Assign(types[2], indices_ty); } else { LOG(FATAL) << "Unsupported ret type: " << param->ret_type; } return true; } Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) { auto attrs = make_object<TopKAttrs>(); attrs->axis = axis; attrs->ret_type = ret_type; attrs->is_ascend = is_ascend; attrs->dtype = dtype; static const Op& op = Op::Get("dyn.topk"); return Call(op, {data, k}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.dyn._make.topk").set_body_typed(MakeTopK); RELAY_REGISTER_OP("dyn.topk") .describe(R"doc(Get the top k elements in an input tensor along the given axis. )doc" TVM_ADD_FILELINE) .set_num_inputs(2) .set_attrs_type<TopKAttrs>() .add_argument("data", "Tensor", "Input data.") .add_argument("k", "Tensor", "Number of top elements.") .set_support_level(6) .add_type_rel("DynTopK", TopKRel); } // namespace dyn } // namespace relay } // namespace tvm