/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/ir/expr.cc
 * \brief The expression AST nodes of Relay.
 */
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/target/virtual_device.h>

#include <limits>

namespace tvm {

/*!
 * \brief Returns the virtual device recorded on this expression, or
 * VirtualDevice::FullyUnconstrained() when none has been recorded.
 */
VirtualDevice RelayExprNode::virtual_device() const {
  if (virtual_device_.defined()) {
    return Downcast<VirtualDevice>(this->virtual_device_);
  }
  return VirtualDevice::FullyUnconstrained();
}

namespace relay {

using tvm::ReprPrinter;
using namespace tvm::runtime;

Constant::Constant(runtime::NDArray data, Span span) {
  ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
  n->data = std::move(data);
  n->span = std::move(span);
  data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(ConstantNode);

TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) {
  return Constant(data);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const ConstantNode*>(ref.get());
      // The constant's payload is rendered by a packed function looked up in the global
      // registry under "relay._constant_repr"; fail loudly if it has not been registered.
      const PackedFunc* fprint = Registry::Get("relay._constant_repr");
      ICHECK(fprint) << "unable to find printing function for constants";
      std::string data = (*fprint)(GetRef<Constant>(node));
      p->stream << "Constant(" << data << ")";
    });

/*!
 * \brief The tensor type of this constant.
 *
 * Each dimension of the underlying NDArray is converted to an int32 IntImm; the ICHECKs
 * fail if a dimension does not fit in the int32 range.
 */
TensorType ConstantNode::tensor_type() const {
  auto dtype = DataType(data->dtype);
  Array<tvm::PrimExpr> shape;
  for (int i = 0; i < data->ndim; i++) {
    ICHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
    ICHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
    shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i]));
  }
  return TensorType(shape, dtype);
}

Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
  ObjectPtr<TupleNode> n = make_object<TupleNode>();
  n->fields = std::move(fields);
  n->span = std::move(span);
  data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(TupleNode);

TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
  return Tuple(fields, span);
});

/*!
 * \brief Returns \p tuple with the given fields replaced.
 *
 * Every optional argument that is not supplied defaults to the tuple's current value.
 * If every resulting field is identical (same_as) to the current one, \p tuple is returned
 * untouched. Otherwise the node is updated via CopyOnWrite and returned.
 *
 * \param tuple The tuple to update.
 * \param opt_fields New fields (compared element-wise by identity).
 * \param opt_virtual_device New virtual device.
 * \param opt_span New span.
 * \return The (possibly updated) tuple.
 */
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
                 Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
  Array<Expr> fields = opt_fields.value_or(tuple->fields);
  VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
  Span span = opt_span.value_or(tuple->span);

  bool unchanged = fields.size() == tuple->fields.size();
  for (size_t i = 0; unchanged && i < fields.size(); i++) {
    unchanged = fields[i].same_as(tuple->fields[i]);
  }
  // The virtual device must take part in the check; otherwise a request that only changes
  // the virtual device would be silently dropped.
  unchanged = unchanged && virtual_device.same_as(tuple->virtual_device()) &&
              span.same_as(tuple->span);

  if (!unchanged) {
    TupleNode* cow_tuple_node = tuple.CopyOnWrite();
    cow_tuple_node->fields = fields;
    cow_tuple_node->virtual_device_ = virtual_device;
    cow_tuple_node->span = span;
  }
  return tuple;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const TupleNode*>(ref.get());
      p->stream << "Tuple(" << node->fields << ")";
    });

Var::Var(Id vid, Type type_annotation, Span span) {
  ObjectPtr<VarNode> n = make_object<VarNode>();
  n->vid = std::move(vid);
  n->type_annotation = std::move(type_annotation);
  n->span = std::move(span);
  data_ = std::move(n);
}
/* static */ Var Var::GenSym(Type type_annotation, Span span) { static size_t next_id = std::atomic<size_t>(0); std::ostringstream os; os << "x_" << next_id++; return Var(os.str(), std::move(type_annotation), std::move(span)); } Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Id vid = opt_vid.value_or(var->vid); Type type_annotation = opt_type_annotation.value_or(var->type_annotation); VirtualDevice virtual_device = opt_virtual_device.value_or(var->virtual_device()); Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && span.same_as(var->span); if (!unchanged) { VarNode* cow_var_node = var.CopyOnWrite(); cow_var_node->vid = vid; cow_var_node->type_annotation = type_annotation; cow_var_node->virtual_device_ = virtual_device; cow_var_node->span = span; } return var; } TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation) { return Var(str, type_annotation); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const VarNode*>(ref.get()); p->stream << "Var(" << node->name_hint(); if (node->type_annotation.defined()) { p->stream << ", ty="; p->Print(node->type_annotation); } p->stream << ")"; }); Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) { ObjectPtr<CallNode> n = make_object<CallNode>(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); n->span = std::move(span); data_ = std::move(n); } Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args, Optional<Attrs> opt_attrs, Optional<Array<Type>> opt_type_args, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Expr op = 
opt_op.value_or(call->op); Array<Expr> args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); Array<Type> type_args = opt_type_args.value_or(call->type_args); VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); // Check that the args are unchanged if (unchanged) { bool all_args_unchanged = true; if (args.size() == call->args.size()) { for (size_t i = 0; i < args.size(); i++) { all_args_unchanged &= args[i].same_as(call->args[i]); } } else { all_args_unchanged = false; } unchanged &= all_args_unchanged; } // Check that the type_args are unchanged if (unchanged) { bool all_type_args_unchanged = true; if (type_args.size() == call->type_args.size()) { for (size_t i = 0; i < type_args.size(); i++) { all_type_args_unchanged &= type_args[i].same_as(call->type_args[i]); } } else { all_type_args_unchanged = false; } unchanged &= all_type_args_unchanged; } if (!unchanged) { CallNode* cow_call_node = call.CopyOnWrite(); cow_call_node->op = op; cow_call_node->args = args; cow_call_node->attrs = attrs; cow_call_node->type_args = type_args; cow_call_node->virtual_device_ = virtual_device; cow_call_node->span = span; } return call; } TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") .set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) { return Call(op, args, attrs, type_args, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const CallNode*>(ref.get()); p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " << node->type_args << ")"; }); Let::Let(Var var, Expr value, Expr body, Span span) { ObjectPtr<LetNode> n = make_object<LetNode>(); n->var = std::move(var); n->value = std::move(value); 
n->body = std::move(body); n->span = std::move(span); data_ = std::move(n); } Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optional<Expr> opt_body, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Var var = opt_var.value_or(let->var); Expr value = opt_value.value_or(let->value); Expr body = opt_body.value_or(let->body); VirtualDevice virtual_device = opt_virtual_device.value_or(let->virtual_device()); Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && span.same_as(let->span); if (!unchanged) { LetNode* cow_let_node = let.CopyOnWrite(); cow_let_node->var = var; cow_let_node->value = value; cow_let_node->body = body; cow_let_node->virtual_device_ = virtual_device; cow_let_node->span = span; } return let; } TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { return Let(var, value, body); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const LetNode*>(ref.get()); p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; }); If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ObjectPtr<IfNode> n = make_object<IfNode>(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); n->span = std::move(span); data_ = std::move(n); } If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch, Optional<Expr> opt_false_branch, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Expr cond = opt_cond.value_or(if_expr->cond); Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); VirtualDevice virtual_device = opt_virtual_device.value_or(if_expr->virtual_device()); 
Span span = opt_span.value_or(if_expr->span); bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); if (!unchanged) { IfNode* cow_if_node = if_expr.CopyOnWrite(); cow_if_node->cond = cond; cow_if_node->true_branch = true_branch; cow_if_node->false_branch = false_branch; cow_if_node->virtual_device_ = virtual_device; cow_if_node->span = span; } return if_expr; } TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { return If(cond, true_branch, false_branch); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const IfNode*>(ref.get()); p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " << node->false_branch << ")"; }); TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>(); n->tuple = std::move(tuple); n->index = index; n->span = std::move(span); data_ = std::move(n); } TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple, Optional<Integer> opt_index, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); Integer index = opt_index.value_or(tuple_get_item->index); VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && span.same_as(tuple_get_item->span); if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); cow_tuple_get_item_node->tuple = tuple; cow_tuple_get_item_node->index = index; cow_tuple_get_item_node->span = span; cow_tuple_get_item_node->virtual_device_ = virtual_device; } return 
tuple_get_item; } TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { return TupleGetItem(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const TupleGetItemNode*>(ref.get()); p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); RefCreate::RefCreate(Expr value, Span span) { ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>(); n->value = std::move(value); n->span = std::move(span); data_ = std::move(n); } RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Expr value = opt_value.value_or(ref_create->value); VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); Span span = opt_span.value_or(ref_create->span); bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); if (!unchanged) { RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); cow_ref_create_node->value = value; cow_ref_create_node->virtual_device_ = virtual_device; cow_ref_create_node->span = span; } return ref_create; } TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { return RefCreate(value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const RefCreateNode*>(ref.get()); p->stream << "RefCreateNode(" << node->value << ")"; }); RefRead::RefRead(Expr ref, Span span) { ObjectPtr<RefReadNode> n = make_object<RefReadNode>(); n->ref = std::move(ref); n->span = std::move(span); data_ = std::move(n); } RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Expr ref = opt_ref.value_or(ref_read->ref); 
VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); Span span = opt_span.value_or(ref_read->span); bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); if (!unchanged) { RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); cow_ref_read_node->ref = ref; cow_ref_read_node->virtual_device_ = virtual_device; cow_ref_read_node->span = span; } return ref_read; } TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const RefReadNode*>(ref.get()); p->stream << "RefReadNode(" << node->ref << ")"; }); RefWrite::RefWrite(Expr ref, Expr value, Span span) { ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>(); n->ref = std::move(ref); n->value = std::move(value); n->span = std::move(span); data_ = std::move(n); } RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> opt_value, Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { Expr ref = opt_ref.value_or(ref_write->ref); Expr value = opt_value.value_or(ref_write->value); VirtualDevice virtual_device = opt_virtual_device.value_or(ref_write->virtual_device()); Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && span.same_as(ref_write->span); if (!unchanged) { RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); cow_ref_write_node->ref = ref; cow_ref_write_node->value = value; cow_ref_write_node->virtual_device_ = virtual_device; cow_ref_write_node->span = span; } return ref_write; } TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { return RefWrite(ref, value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<RefWriteNode>([](const ObjectRef& 
ref, ReprPrinter* p) { auto* node = static_cast<const RefWriteNode*>(ref.get()); p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) { return temp->Realize(); }); TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any(); }); /* * Non-recursive traversal with dismantling unused call nodes, * a derivative from ExpandDataflow method */ inline void Dismantle(const Expr& expr) { std::stack<std::pair<Expr, bool>> stack; auto fpush_to_stack = [&stack](const Expr& expr) { // do not visit nodes with more than 2 refs (one can be in stack) if (expr.use_count() < 3) { stack.push({expr, false}); } }; fpush_to_stack(expr); while (stack.size() > 0) { const auto& node = stack.top().first; if (stack.top().second) { // dismantle node // +1 ref in stack/deque; if (node.use_count() < 3) { if (auto* op = const_cast<CallNode*>(node.as<CallNode>())) { op->args = Array<Expr>(); } if (auto* op = const_cast<LetNode*>(node.as<LetNode>())) { op->body = Expr(); } } // eject stack.pop(); } else { stack.top().second = true; // special handling if (const auto* call_node = node.as<CallNode>()) { // do not process args if used elsewhere if (call_node->args.use_count() < 2) { for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) { fpush_to_stack(*it); } } } else if (const auto* tuple_node = node.as<TupleNode>()) { // do not process fields if used elsewhere if (tuple_node->fields.use_count() < 2) { for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) { fpush_to_stack(*it); } } } else if (const auto* tuple_get_item_node = node.as<TupleGetItemNode>()) { // do not process tuple if used elsewhere if (tuple_get_item_node->tuple.use_count() < 2) { fpush_to_stack(tuple_get_item_node->tuple); } } else if (const auto* let_node = node.as<LetNode>()) { // do not process let if used elsewhere if (let_node->body.use_count() < 2) { 
fpush_to_stack(let_node->body); } } } } } /* * Non-recursive destructor */ Call::~Call() { // attempt to dismantle if referenced one or zero times if (this->use_count() < 2) { if (this->as<CallNode>() && this->as<CallNode>()->args.size()) { Dismantle(*this); } } } /* * CallNode's deleter */ void CallNode::Deleter_(Object* ptr) { auto p = reinterpret_cast<CallNode*>(ptr); // resore original deleter p->deleter_ = p->saved_deleter_; // create Call reference in order to invoke ~Call auto c = GetRef<Call>(p); } /* * Non-recursive destructor */ Let::~Let() { // attempt to dismantle if referenced one or zero times if (this->use_count() < 2) { if (this->as<LetNode>() && this->as<LetNode>()->body.defined()) { Dismantle(*this); } } } /* * LetNode's deleter */ void LetNode::Deleter_(Object* ptr) { auto p = reinterpret_cast<LetNode*>(ptr); // resore original deleter p->deleter_ = p->saved_deleter_; // create Let reference in order to invoke ~Let auto c = GetRef<Let>(p); } } // namespace relay } // namespace tvm