/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/tir/stmt.cc */ #include #include #include #include #include #include "buffer_common.h" namespace tvm { namespace tir { // LetStmt LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK(value.defined()); ICHECK(body.defined()); auto vdtype = value.dtype(); // It is still valid to bind a pointer type // var to a value that is of type handle. if (var->type_annotation.as()) { ICHECK(vdtype.is_handle()); } else { ICHECK_EQ(value.dtype(), var.dtype()); } ObjectPtr node = make_object(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.LetStmt") .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) { return LetStmt(var, value, body, span); }); TVM_REGISTER_NODE_TYPE(LetStmtNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "let " << op->var << " = "; p->Print(op->value); p->stream << '\n'; p->Print(op->body); }); // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); n->body = std::move(body); n->span = std::move(span); data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.AttrStmt") .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { return AttrStmt(node, attr_key, value, body, span); }); TVM_REGISTER_NODE_TYPE(AttrStmtNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "// attr ["; p->Print(op->node); p->stream << "] " << op->attr_key << " = "; p->Print(op->value); p->stream << '\n'; p->Print(op->body); }); // AssertStmt AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { ICHECK(condition.defined()); ICHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; ObjectPtr node = make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(AssertStmtNode); TVM_REGISTER_GLOBAL("tir.AssertStmt") .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body, Span span) { if (const auto* str = message.as()) { auto msg = StringImm(str->data); return AssertStmt(condition, msg, body, span); } else { return AssertStmt(condition, Downcast(message), body, span); } }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "assert("; p->Print(op->condition); p->stream << ", "; p->Print(op->message); p->stream << ")\n"; p->Print(op->body); }); // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional thread_binding, Map annotations, Span span) { ICHECK(min.defined()); ICHECK(extent.defined()); ICHECK(min.dtype().is_scalar()); ICHECK(extent.dtype().is_scalar()); ICHECK(loop_var.dtype().is_scalar()); ICHECK(body.defined()); ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); node->kind = kind; node->body = std::move(body); node->thread_binding = std::move(thread_binding); node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.For").set_body_typed( [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional thread_binding, Optional> annotations, Span span) { return For(loop_var, min, extent, static_cast(kind), body, thread_binding, annotations.value_or(Map()), span); }); TVM_REGISTER_NODE_TYPE(ForNode); std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) switch (type) { case ForKind::kSerial: out << "for"; break; case ForKind::kParallel: out << "parallel"; break; case ForKind::kUnrolled: out << "unrolled"; break; case ForKind::kVectorized: out << "vectorized"; break; case ForKind::kThreadBinding: out << "launch_thread"; break; } return out; } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->kind << " (" << op->loop_var << ", "; p->Print(op->min); p->stream << ", "; p->Print(op->extent); p->stream << ") {\n"; p->indent += 2; p->Print(op->body); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; }); // While While::While(PrimExpr condition, Stmt body, Span span) { ICHECK(condition.defined()); ICHECK(condition.dtype().is_scalar()); ICHECK(condition.as() == nullptr) << "The condition should not be trivial."; ICHECK(body.defined()); ObjectPtr node = make_object(); node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { return While(condition, body, span); }); TVM_REGISTER_NODE_TYPE(WhileNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "while(" << op->condition << ") {\n"; p->indent += 2; p->Print(op->body); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; }); // Store Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); // Assume that the array elements have 1 lane, unless a type // annotation tells us otherwise. int element_lanes = 1; auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); if (pointer_type.first) { // Currently cannot check element type of array, see Load::Load // for details. // TODO(Lunderberg): Uncomment this check once it can be applied. // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 // for discussion. // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of()) // << "Type mismatch, cannot store type " << value.dtype() << " into buffer " // << buffer_var->name_hint << " of type " << pointer_type.second; element_lanes = pointer_type.second.lanes(); } ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || (value.dtype().lanes() == index.dtype().lanes())); ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) || (value.dtype().lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->value = std::move(value); node->index = std::move(index); node->predicate = std::move(predicate); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { PrimExpr value = args[1]; if (args.size() == 3) { *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes()), Span()); } else if (args.size() == 4) { *ret = Store(args[0], value, args[2], args[3], Span()); } else { *ret = Store(args[0], value, args[2], args[3], args[4]); } }); TVM_REGISTER_NODE_TYPE(StoreNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->buffer_var << "["; p->Print(op->index); p->stream << "] = "; p->Print(op->value); if (!is_one(op->predicate)) { p->stream << " if "; p->Print(op->predicate); } p->stream << '\n'; }); // ProducerStore ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices, Span span) { ObjectPtr node = make_object(); node->producer = std::move(producer); node->value = std::move(value); node->indices = std::move(indices); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerStore") .set_body_typed([](DataProducer producer, PrimExpr value, Array indices, Span span) { return ProducerStore(producer, value, indices, span); }); TVM_REGISTER_NODE_TYPE(ProducerStoreNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->producer->GetNameHint() << "["; for (size_t i = 0; i < op->indices.size(); ++i) { p->Print(op->indices[i]); if (i < op->indices.size() - 1) p->stream << ", "; } p->stream << "]"; p->stream << " ="; p->Print(op->value); p->stream << '\n'; }); // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { CHECK(IsPointerType(buffer_var->type_annotation, dtype)) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" << buffer_var->type_annotation << "). The data type should be an element of the pointer type."; for (size_t i = 0; i < extents.size(); ++i) { ICHECK(extents[i].defined()); ICHECK(extents[i].dtype().is_scalar()); } ICHECK(body.defined()); ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); node->condition = std::move(condition); node->body = std::move(body); node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } int32_t AllocateNode::constant_allocation_size(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode* int_size = extents[i].as()) { result *= int_size->value; if (result > std::numeric_limits::max()) { return 0; } } else { return 0; } } return static_cast(result); } TVM_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { return Allocate(buffer_var, type, extents, condition, body, annotations, span); }); TVM_REGISTER_NODE_TYPE(AllocateNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); const auto* ptr_type = op->buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; p->PrintIndent(); p->stream << "allocate " << op->buffer_var << "[" << op->dtype; for (size_t i = 0; i < op->extents.size(); ++i) { p->stream << " * "; p->Print(op->extents[i]); } p->stream << "], storage_scope = " << ptr_type->storage_scope; if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); } p->stream << "\n"; p->Print(op->body); }); // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); ICHECK(bounds[i]->min.dtype().is_scalar()); ICHECK(bounds[i]->extent.dtype().is_scalar()); } ICHECK(body.defined()); ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); ObjectPtr node = make_object(); node->producer = std::move(producer); node->bounds = std::move(bounds); node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); node->storage_scope = std::move(storage_scope); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { return ProducerRealize(producer, bounds, condition, body, storage_scope, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "producer_realize " << op->producer->GetNameHint() << "("; for (size_t i = 0; i < op->bounds.size(); ++i) { p->stream << "["; p->Print(op->bounds[i]->min); p->stream << ", "; p->Print(op->bounds[i]->extent); p->stream << "]"; if (i < op->bounds.size() - 1) p->stream << ", "; } p->stream << ")"; if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); } p->stream << " {\n"; p->indent += 2; p->Print(op->body); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; }); // Prefetch Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { data_ = make_object(buffer, bounds, span); } TVM_REGISTER_GLOBAL("tir.Prefetch") .set_body_typed([](Buffer buffer, Array bounds, Span span) { return Prefetch(buffer, bounds, span); }); TVM_REGISTER_NODE_TYPE(PrefetchNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "prefetch " << op->buffer << "("; for (size_t i = 0; i < op->bounds.size(); ++i) { p->stream << "["; p->Print(op->bounds[i]->min); p->stream << ", "; p->Print(op->bounds[i]->extent); p->stream << "]"; if (i < op->bounds.size() - 1) p->stream << ", "; } p->stream << ")"; }); // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { auto node = make_object(); node->seq = std::move(seq); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { return SeqStmt(std::move(seq), span); }); TVM_REGISTER_NODE_TYPE(SeqStmtNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); for (Stmt stmt : op->seq) { p->Print(stmt); } }); // IfThenElse IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { ICHECK(condition.defined()); ICHECK(then_case.defined()); // else_case may be null. ObjectPtr node = make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(IfThenElseNode); TVM_REGISTER_GLOBAL("tir.IfThenElse") .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { return IfThenElse(condition, then_case, else_case, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); while (true) { p->stream << "if (" << op->condition << ") {\n"; p->indent += 2; p->Print(op->then_case); p->indent -= 2; if (!op->else_case.defined()) { break; } if (const IfThenElseNode* nested_if = op->else_case.as()) { p->PrintIndent(); p->stream << "} else "; op = nested_if; } else { p->PrintIndent(); p->stream << "} else {\n"; p->indent += 2; p->Print(op->else_case); p->indent -= 2; break; } } p->PrintIndent(); p->stream << "}\n"; }); // Evaluate Evaluate::Evaluate(PrimExpr value, Span span) { ICHECK(value.defined()); ObjectPtr node = make_object(); node->value = std::move(value); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { return Evaluate(value, span); }); TVM_REGISTER_NODE_TYPE(EvaluateNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->Print(op->value); p->stream << "\n"; }); // BufferStore BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { return BufferStore(buffer, value, indices, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->buffer->name << "["; for (size_t i = 0; i < op->indices.size(); ++i) { p->Print(op->indices[i]); if (i < op->indices.size() - 1) p->stream << ", "; } p->stream << "]"; p->stream << " = "; p->Print(op->value); p->stream << '\n'; }); // BufferRealize BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span) { data_ = make_object(buffer, bounds, condition, body, span); } TVM_REGISTER_GLOBAL("tir.BufferRealize") .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); TVM_REGISTER_NODE_TYPE(BufferRealizeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << "buffer_realize " << op->buffer->name << "("; for (size_t i = 0; i < op->bounds.size(); ++i) { p->stream << "["; p->Print(op->bounds[i]->min); p->stream << ", "; p->Print(op->bounds[i]->extent); p->stream << "]"; if (i < op->bounds.size() - 1) p->stream << ", "; } p->stream << ")"; if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); } p->stream << " {\n"; p->indent += 2; p->Print(op->body); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; }); // BufferRegion BufferRegion::BufferRegion(Buffer buffer, Array region) { CHECK_EQ(buffer->shape.size(), region.size()) << "The dimension between " << buffer << " and region " << region << " mismatched, the buffer is " << buffer; ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->region = std::move(region); data_ = std::move(node); } BufferRegion BufferRegion::FullRegion(Buffer buffer) { Array region; for (PrimExpr extent : buffer->shape) { region.push_back(Range::FromMinExtent(0, extent)); } return BufferRegion(buffer, region); } BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { Array region; for (const PrimExpr& index : indices) { region.push_back(Range::FromMinExtent(index, 1)); } return BufferRegion(buffer, region); } TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { return BufferRegion(buffer, region); }); TVM_REGISTER_NODE_TYPE(BufferRegionNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->buffer->name; p->stream << "["; for (size_t i = 0; i < op->region.size(); ++i) { const auto& range = op->region[i]; p->Print(range->min); if (!is_one(range->extent)) { p->stream << ":"; p->Print(range->min + range->extent); } if (i != op->region.size() - 1) p->stream << ", "; } p->stream << "]"; }); // MatchBufferRegion MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { const Buffer& source_buffer = source->buffer; arith::Analyzer analyzer; // Check scope and dtype CHECK_EQ(buffer.scope(), source_buffer.scope()) << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " << source_buffer.scope(); CHECK_EQ(buffer->dtype, source_buffer->dtype) << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " << source_buffer->dtype; // Check data_alignment CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) << "Trying to match buffer to another one with lower alignment requirement " << " required_alignment=" << buffer->data_alignment << ", provided_alignment=" << source_buffer->data_alignment; // Check BufferType. AutoBroadcast is not allowed for now. CHECK(buffer->buffer_type == BufferType::kDefault && source_buffer->buffer_type == BufferType::kDefault) << "AutoBroadcast is not allowed in MatchBuffer"; // Validate shape CHECK(source->region.size() >= buffer->shape.size()) << "Dimension of source Region expected to be larger or equal than target buffer shape, but " "got " << source->region.size() << " vs. " << buffer->shape.size(); size_t offset = source->region.size() - buffer->shape.size(); for (size_t i = 0; i < offset; ++i) { CHECK(analyzer.CanProve(source->region[i]->extent == 1)) << "The higher dimension should be 1, but got " << source->region[i]->extent << "."; } for (size_t i = 0; i < buffer->shape.size(); ++i) { const Range& source_range = source->region[i + offset]; const PrimExpr& buffer_shape = buffer->shape[i]; if (!buffer_shape->IsInstance()) { CHECK(analyzer.CanProve(source_range->extent == buffer_shape)) << "The dimension mismatched between source region and target buffer shape, got " << source_range->extent << " vs. " << buffer_shape << "."; } } // Note that we do not check elem_offset and strides in this function // Construction ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->source = std::move(source); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) { return MatchBufferRegion(buffer, source); }); TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); p->stream << op->buffer->name << " = match_buffer("; p->Print(op->source); p->stream << ")\n"; }); // Block Block::Block(Array iter_vars, Array reads, Array writes, String name_hint, Stmt body, Optional init, Array alloc_buffers, Array match_buffers, Map annotations, Span span) { ObjectPtr node = make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); node->writes = std::move(writes); node->name_hint = std::move(name_hint); node->body = std::move(body); node->init = std::move(init); node->alloc_buffers = std::move(alloc_buffers); node->match_buffers = std::move(match_buffers); node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Block") .set_body_typed([](Array iter_vars, Array reads, Array writes, String name_hint, Stmt body, Optional init, Array alloc_buffers, Array match_buffers, Map annotations, Span span) { return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, annotations, span); }); TVM_REGISTER_NODE_TYPE(BlockNode); void PrintBlockTitle(const BlockNode* op, ReprPrinter* p) { p->stream << "block " << op->name_hint << "("; for (size_t i = 0; i < op->iter_vars.size(); i++) { p->Print(op->iter_vars[i]); if (i < op->iter_vars.size() - 1) p->stream << ", "; } p->stream << ")"; } void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) { // print read/write regions p->PrintIndent(); p->stream << "reads("; p->Print(op->reads); p->stream << ")\n"; p->PrintIndent(); p->stream << "writes("; p->Print(op->writes); p->stream << ")\n"; // Print alloc_buffers for (const auto& alloc_buf : op->alloc_buffers) { p->PrintIndent(); p->stream << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "["; for (size_t i = 0; i < alloc_buf->shape.size(); ++i) { if (i > 0) p->stream << ", "; p->Print(alloc_buf->shape[i]); } p->stream << "])\n"; } // Print match_buffer_regions for (const auto& match_buf : op->match_buffers) { p->Print(match_buf); } if (!op->annotations.empty()) { p->PrintIndent(); p->stream << "annotations(" << op->annotations << ")\n"; } } void PrintBlockBody(const BlockNode* op, ReprPrinter* p) { // Print init if (op->init.defined()) { p->PrintIndent(); p->stream << "with init() {\n"; p->indent += 2; p->Print(op->init.value()); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; } // Print body p->Print(op->body); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); PrintBlockTitle(op, p); p->stream << " {\n"; p->indent += 2; // Print block elements (e.g. reads/writes, etc) PrintBlockSignature(op, p); // Print block init and body PrintBlockBody(op, p); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; }); // BlockRealize BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block block, Span span) { CHECK_EQ(block->iter_vars.size(), values.size()) << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values"; CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression"; ObjectPtr node = make_object(); node->iter_values = std::move(values); node->predicate = std::move(predicate); node->block = std::move(block); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BlockRealize") .set_body_typed([](Array iter_values, PrimExpr predicate, Block block, Span span) { return BlockRealize(iter_values, predicate, block, span); }); TVM_REGISTER_NODE_TYPE(BlockRealizeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); auto* block_op = op->block.get(); p->PrintIndent(); PrintBlockTitle(block_op, p); p->stream << " {\n"; p->indent += 2; // Print binding iter_values for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { p->PrintIndent(); p->stream << "bind("; p->Print(block_op->iter_vars[i]->var); p->stream << ", "; p->Print(op->iter_values[i]); p->stream << ")\n"; } // Print predicate if (!is_one(op->predicate)) { p->PrintIndent(); p->stream << "where("; p->Print(op->predicate); p->stream << ")\n"; } // Print block elements (e.g. reads/writes, etc) PrintBlockSignature(block_op, p); // Print block init and body PrintBlockBody(block_op, p); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; }); PrimExpr TypeAnnotation(DataType dtype, Span span) { static auto op = Op::Get("tir.type_annotation"); return tir::Call(dtype, op, {}, span); } TVM_REGISTER_OP("tir.type_annotation") .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); } // namespace tir } // namespace tvm