/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tensor.cc */ #include #include #include #include #include namespace tvm { namespace te { IterVar thread_axis(Range dom, std::string tag) { return IterVar(dom, Var(tag), kThreadIndex, tag); } IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name), kCommReduce); } Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor PrimExpr Tensor::operator()(Array indices) const { Array arr(indices.begin(), indices.end()); return operator()(arr); } PrimExpr Tensor::operator()(Array indices) const { if (ndim() != 0) { ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " << "ndim = " << ndim() << ", indices.size=" << indices.size(); } return ProducerLoad((*this), indices); } String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } Tensor Operation::output(size_t i) const { auto node = make_object(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); node->shape = (*this)->output_shape(i); return Tensor(node); } Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; n->value_index = value_index; data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.Tensor") .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { return Tensor(shape, dtype, op, value_index); }); TVM_REGISTER_NODE_TYPE(TensorNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* t = static_cast(node.get()); p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; }); // TensorIntrin TensorIntrin::TensorIntrin(std::string name, Operation op, Array inputs, Array buffers, Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { auto n = make_object(); n->name = std::move(name); n->op = std::move(op); n->inputs = std::move(inputs); n->buffers = std::move(buffers); n->scalar_params = std::move(scalar_params); n->body = std::move(body); n->reduce_init = std::move(reduce_init); n->reduce_update = std::move(reduce_update); data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.TensorIntrin") .set_body_typed([](std::string name, Operation op, Array inputs, Array buffers, Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init, reduce_update); }); TVM_REGISTER_NODE_TYPE(TensorIntrinNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; }); // TensorIntrinCall TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array tensors, Array regions, Array reduce_axis, Array scalar_inputs) { auto n = make_object(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); n->reduce_axis = std::move(reduce_axis); n->scalar_inputs = std::move(scalar_inputs); data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.TensorIntrinCall") .set_body_typed([](TensorIntrin intrin, Array tensors, Array regions, Array reduce_axis, Array scalar_inputs) { return TensorIntrinCall(intrin, tensors, regions, reduce_axis, scalar_inputs); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* n = static_cast(node.get()); p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; }); TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); // Other tensor ops. TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { return static_cast(std::hash()(tensor)); }); TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); } // namespace te } // namespace tvm