/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tir_text_printer.cc * \brief Printer to print out the IR text format * that can be parsed by a parser. */ #include #include #include #include #include #include #include #include #include #include #include #include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" namespace tvm { namespace tir { Doc TIRTextPrinter::Print(const ObjectRef& node) { if (!node.defined()) return Doc::Text("(nullptr)"); if (node->IsInstance()) { return VisitStmt(Downcast(node)); } else if (node->IsInstance()) { return Doc::Text("?"); } else if (node->IsInstance()) { return VisitExpr(Downcast(node)); } else if (node->IsInstance()) { return VisitType(Downcast(node)); } else if (node->IsInstance()) { return PrintPrimFunc(Downcast(node)); } else if (node->IsInstance()) { return PrintIRModule(Downcast(node)); } else if (node->IsInstance()) { return PrintArray(node.as()); } else if (node->IsInstance()) { return PrintIterVar(node.as()); } else if (node->IsInstance()) { return PrintRange(node.as()); } else if (node->IsInstance()) { return PrintBuffer(node.as()); } else if (node->IsInstance()) { return PrintProducer(node.as()); } else if (node->IsInstance()) { return PrintString(node.as()); } else if (node->IsInstance()) { return PrintBufferRegion(node.as()); } else if (node->IsInstance()) { return Doc::Text(node.as()->ToDebugString()); } else { return this->meta_->GetMetaNode(node); } } Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { const auto* op = prim_func.operator->(); const auto& signature = op->func_type_annotation(); // collect Meta in DictAttr if (prim_func->attrs.defined()) { for (const auto& it : prim_func->attrs->dict) { meta_collector_.Collect(it.second); } } // collect buffers in buffer_map memo_var_.clear(); memo_buf_.clear(); for (const auto& it : op->buffer_map) { memo_buf_[it.second] = AllocBuf(it.second); } // print PrimFunc Doc doc; doc << "primfn" << "("; // print params and its type annotation std::vector params; for (const auto& param : op->params) { params.push_back(Print(param)); } Doc sep; doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")"; // print return type doc << " -> " << Print(signature->ret_type); // print attr Doc attr_doc; std::vector attr_docs; if (prim_func->attrs.defined()) { for (const auto& it : op->attrs->dict) { attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); } attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; doc << Doc::Indent(2, attr_doc); } // print all the buffers in the tree if (memo_buf_.size() != 0) { Doc buffer_doc; std::vector buffer_docs; for (const auto& it : memo_buf_) { const auto& buf = it.first; buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf))); } buffer_doc << Doc::NewLine() << "buffers = {"; buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); doc << Doc::Indent(2, buffer_doc) << "}"; } if (op->buffer_map.size() != 0) { // print buffer_map std::vector buffer_map_doc; for (const auto& it : op->buffer_map) { buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); } doc << Doc::Indent( 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } doc << PrintBody(op->body); return doc; } Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { const auto* op = module.operator->(); Doc doc; Doc body; body << Doc::NewLine(); std::vector functions; for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { if ((*it).second.as()) { functions.push_back(Print((*it).second)); } } body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); doc << Doc::Indent(0, body); return doc; } Doc TIRTextPrinter::PrintArray(const ArrayNode* op) { Doc doc; doc << '['; for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { doc << ", "; } doc << Print(op->at(i)); } doc << ']'; return doc; } Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) { Doc doc; doc << "IterVar(" << Print(op->var); if (op->dom.defined()) { doc << ", [" << Print(op->dom) << "], "; } else { doc << ", " << Print(op->dom) << ", "; } doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "; doc << Doc::StrLiteral(op->thread_tag) << ")"; return doc; } Doc TIRTextPrinter::PrintRange(const RangeNode* op) { return Print(op->min) << ":" << Print(op->min + op->extent); } Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { const Buffer& buffer = GetRef(op); if (meta_->InMeta(buffer)) { return meta_->GetMetaNode(buffer); } else if (memo_buf_.count(buffer)) { return memo_buf_[buffer]; } else { memo_buf_[buffer] = AllocBuf(buffer); return BufferNode2Doc(op, memo_buf_[buffer]); } } Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) { const DataProducer& prod = GetRef(op); if (meta_->InMeta(prod)) { return meta_->GetMetaNode(prod); } else if (memo_producer_.count(prod)) { return memo_producer_[prod]; } else { memo_producer_[prod] = AllocProducer(prod); return DataProducerNode2Doc(op, memo_producer_[prod]); } } Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " << Print(buf->strides); if (!is_zero(buf->elem_offset)) { doc << ", elem_offset=" << Print(buf->elem_offset); } if (GetRef(buf).scope() != "global") { doc << ", scope=" << Doc::StrLiteral(GetRef(buf).scope()); } if (buf->data_alignment != 128) { doc << ", align=" << buf->data_alignment; } if (buf->offset_factor != 1) { doc << ", offset_factor=" << buf->offset_factor; } if (buf->buffer_type != 1) { doc << ", type=" << Doc::StrLiteral("auto"); } return doc << ")"; } Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) { return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", " << PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")"; } Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; doc << Print(op->buffer) << "["; for (size_t i = 0; i < op->region.size(); ++i) { if (i != 0) { doc << ", "; } const auto& range = op->region[i]; if (!is_one(range->extent)) { doc << Print(range->min) << ":" << Print(range->min + range->extent); } else { doc << Print(range->min); } } doc << "]"; return doc; } Doc TIRTextPrinter::VisitExprDefault_(const Object* op) { return this->meta_->GetMetaNode(GetRef(op)); } Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) { return this->meta_->GetMetaNode(GetRef(op)); } Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) { return PrintConstScalar(op->dtype, op->value); } Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) { return PrintConstScalar(op->dtype, op->value); } Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { Doc doc; doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const VarNode* op) { const Var& var = GetRef(op); return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef(op)); } #define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \ Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \ Doc doc; \ doc << "(" << Print(op->a) << OpString; \ doc << Print(op->b) << ")"; \ return doc; \ } TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AddNode, " + ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(SubNode, " - ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(MulNode, "*") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(DivNode, " / ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(ModNode, " % ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(EQNode, " == ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(NENode, " != ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LTNode, " < ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LENode, " <= ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GTNode, " > ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GENode, " >= ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AndNode, " && ") TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OrNode, " || ") Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) { Doc doc; doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) { Doc doc; doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const MinNode* op) { Doc doc; doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) { Doc doc; doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const NotNode* op) { Doc doc; doc << "!" << Print(op->a); return doc; } Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) { Doc doc; doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " << Print(op->false_value) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) { Doc doc; doc << Print(op->buffer) << Print(op->indices); return doc; } Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) { // TODO(tvm-team): consider make a better text format for producer. Doc doc; doc << op->producer->GetNameHint() << Print(op->indices); return doc; } Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { Doc doc; doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index) << "]"; if (!is_one(op->predicate)) { doc << " if " << Print(op->predicate); } return doc; } Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { Doc doc; doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { Doc doc; doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const LetNode* op) { Doc doc; doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body); return doc; } Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { Doc doc; if (auto* ptr_op = op->op.as()) { doc << "@" << Doc::Text(ptr_op->name) << "("; } else { // TODO(bohan): Print out the name by he global var in the module. auto* op_gvar = op->op.as(); ICHECK(op_gvar != nullptr); doc << "@" << Doc::Text(op_gvar->name_hint) << "("; } std::vector args; for (const auto& arg : op->args) { args.push_back(Print(arg)); } doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) { Doc doc; doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { Doc doc; doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis) << ", " << op->value_index << ", " << Print(op->init) << ")"; return doc; } Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body); return doc; } Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; meta_collector_.Collect(op->node); doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = " << Print(op->value); if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { doc << ";" << Doc::NewLine() << Print(op->body); } return doc; } Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { Doc doc; doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine() << Print(op->body); return doc; } Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) { Doc doc; doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); if (!is_one(op->predicate)) { doc << " if " << Print(op->predicate); } return doc; } Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { Doc doc; doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); return doc; } Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) { Doc doc; doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value); return doc; } Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc doc; doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " << Print(op->condition) << PrintBody(op->body) << ")"; return doc; } Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) { Doc doc; doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", " << Print(op->condition) << ", " << PrintBody(op->body) << ")"; return doc; } Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); doc << "allocate(" << Print(op->buffer_var) << ", "; doc << PrintDType(op->dtype) << ", "; doc << Print(op->extents) << "), storage_scope = " << scope; if (!op->annotations.empty()) { std::vector attr_docs; for (const auto& it : op->annotations) { attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); } doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})"; } if (!is_one(op->condition)) { doc << " if " << Print(op->condition); } if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { doc << ";" << Doc::NewLine() << Print(op->body); } return doc; } Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << PrintBody(op->then_case); if (!is_one(op->condition) && op->else_case.defined()) { doc << " else" << PrintBody(op->else_case); } return doc; } Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) { std::vector stmts; Doc seq_doc, doc; for (Stmt stmt : op->seq) { seq_doc << Doc::NewLine() << Print(stmt); } doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}"; return doc; } Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { Doc doc; doc << Print(op->value); return doc; } Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " << Print(op->min + op->extent) << ")"; if (op->kind != ForKind::kSerial) { doc << " " << Doc::StrLiteral(ForKind2String(op->kind)); } doc << PrintBody(op->body); return doc; } Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) { Doc doc; doc << "while (" << Print(op->condition) << ")"; doc << PrintBody(op->body); return doc; } Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { Doc doc; doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; return doc; } Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); // print block name and block vars Doc doc; doc << "block(["; std::vector block_var_docs; for (const auto& iter_var : block_op->iter_vars) { Doc block_var_doc; if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { block_var_doc << Print(iter_var->dom->extent); } else { block_var_doc << "tir."; switch (iter_var->iter_type) { case kDataPar: block_var_doc << "range"; break; case kCommReduce: block_var_doc << "reduce_axis"; break; case kOrdered: block_var_doc << "scan_axis"; break; case kOpaque: block_var_doc << "opaque_axis"; break; default: LOG(FATAL) << "Unknown block var iter type"; break; } block_var_doc << "(" << Print(iter_var->dom->min) << ", " << Print(iter_var->dom->min + iter_var->dom->extent) << ")"; } block_var_docs.push_back(block_var_doc); } doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], "; doc << Doc::StrLiteral(block_op->name_hint) << ")"; std::vector block_var_names; for (const auto& iter_var : block_op->iter_vars) { Doc block_var_name; AllocVar(iter_var->var); block_var_names.push_back(Print(iter_var->var)); } if (!block_var_names.empty()) { doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]"; } doc << " {"; Doc block_attr_doc; // print predicate, binding, read/write tensor region, annotations if (!is_one(op->predicate)) { block_attr_doc << Doc::NewLine() << "where(" << Print(op->predicate) << ")"; } for (size_t i = 0; i < block_op->iter_vars.size(); ++i) block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) << ")"; block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")"; block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { std::vector attr_docs; for (const auto& it : block_op->annotations) { attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); } block_attr_doc << Doc::NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) << "})"; } // print body Doc body; body << Doc::NewLine(); for (const auto& alloc_buf : block_op->alloc_buffers) { body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype) << Print(alloc_buf->shape) << ")" << Doc::NewLine(); } for (const auto& match_buf : block_op->match_buffers) { body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" << Doc::NewLine(); } if (block_op->init.defined()) { Doc init_block; init_block << "with init()"; init_block << PrintBody(block_op->init.value()); body << init_block << Doc::NewLine(); } body << Print(block_op->body); doc << Doc::Indent(2, block_attr_doc << body); return doc; } Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) { Doc doc; doc << PrintDType(node->dtype); return doc; } Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) { Doc doc; doc << "Pointer("; if (!node->storage_scope.empty()) { doc << node->storage_scope << " "; } doc << Print(node->element_type) << ")"; return doc; } Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) { std::vector fields; for (Type field : node->fields) { fields.push_back(Print(field)); } Doc doc; doc << "(" << Doc::Concat(fields); // conform to python tuple format (1,) if (node->fields.size() == 1) { doc << ","; } return doc << ")"; } Doc TIRTextPrinter::PrintDType(DataType dtype) { return Doc::Text(runtime::DLDataType2String(dtype)); } template Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) { Doc doc; std::ostringstream os; os << data; if (dtype == DataType::Int(32)) { doc << Doc::Text(os.str()); } else { if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) { doc << ((data == 1) ? "True" : "False"); return doc; } doc << Doc::Text(os.str()); switch (dtype.code()) { case kDLInt: doc << "i"; break; case kDLUInt: doc << "u"; break; case kDLFloat: doc << "f"; break; } doc << Doc::Text(std::to_string(dtype.bits())); if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes())); } return doc; } Doc TIRTextPrinter::GetUniqueName(std::string prefix) { // std::replace(prefix.begin(), prefix.end(), '.', '_'); std::string unique_prefix = prefix; auto it = name_alloc_map_.find(prefix); if (it != name_alloc_map_.end()) { while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { } } name_alloc_map_[unique_prefix] = 0; return Doc::Text(unique_prefix); } Doc TIRTextPrinter::AllocVar(const Var& var) { const auto& it = memo_var_.find(var); if (it != memo_var_.end()) { return it->second; } std::string name = var->name_hint.operator std::string(); if (name.length() == 0 || !std::isalpha(name[0])) { name = "v" + name; } Doc val = GetUniqueName(name); memo_var_[var] = val; return val << ": " << Print(GetType(var)); } Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { const auto& it = memo_buf_.find(buffer); if (it != memo_buf_.end()) { return it->second; } std::string name = buffer->name; if (name.length() == 0 || !std::isalpha(name[0])) { name = "buf_" + name; } Doc val = GetUniqueName(name); memo_buf_[buffer] = val; return val; } Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) { const auto& it = memo_producer_.find(producer); if (it != memo_producer_.end()) { return it->second; } std::string name = producer->GetNameHint(); if (name.length() == 0 || !std::isalpha(name[0])) { name = "tensor_" + name; } Doc val = GetUniqueName(name); memo_producer_[producer] = val; return val; } Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) { seq = vec[0]; for (size_t i = 1; i < vec.size(); i++) { seq << sep << vec[i]; } } return seq; } Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { Doc doc; if (body->IsInstance()) return Print(body); doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}"; return doc; } bool TIRTextPrinter::GetVarName(Var v, std::string* s) { auto it = memo_var_.find(v); if (it == memo_var_.end()) { return false; } *s = it->second.str(); return true; } } // namespace tir } // namespace tvm