/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file text_printer.h * \brief Printer to print out the unified IR text format * that can be parsed by a parser. */ #ifndef TVM_PRINTER_TEXT_PRINTER_H_ #define TVM_PRINTER_TEXT_PRINTER_H_ #include <tvm/ir/module.h> #include <tvm/ir/type_functor.h> #include <tvm/relay/expr_functor.h> #include <tvm/relay/pattern_functor.h> #include <tvm/tir/expr_functor.h> #include <tvm/tir/function.h> #include <tvm/tir/op.h> #include <tvm/tir/stmt_functor.h> #include <tvm/tir/var.h> #include <string> #include <unordered_map> #include <unordered_set> #include <vector> #include "../ir/attr_functor.h" #include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" namespace tvm { class TextPrinter; } // namespace tvm namespace tvm { namespace relay { class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>, public PatternFunctor<Doc(const Pattern&)>, public TypeFunctor<Doc(const Type&)>, public AttrFunctor<Doc(const ObjectRef&)> { public: explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} Doc VisitExpr(const Expr& expr) override; virtual Doc VisitLeaf(const Expr& expr); virtual bool CheckVisited(const Expr& expr); /*! * \brief Print additional info about expr in comment. * \param expr The expression. */ Doc PrintOptionalInfo(const Expr& expr); // indent a new body Doc PrintBody(const ObjectRef& node, int indent = 2); // create a new scope by creating a new printer object. This allows temp var // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const ObjectRef& node); Doc PrintFinal(const ObjectRef& node); /*! * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence * of key=value entries, if any. */ void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key); /*! * \brief Returns \p attrs printed as a sequence of key=value entries, if any. * This is used for call attributes. */ std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op); /*! * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any. * This is used for function definition attributes. */ std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs); std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs); /*! * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag. */ Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false); /*! * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces. */ Doc PrintAttrsAsAttributeValue(const Attrs& attrs); /*! * \brief Returns \p map printed as a self-contained value, ie wrapped in braces. */ Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map); Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); Doc TempVar(int n); Doc AllocTemp(); /*! * \brief get a unique name with the corresponding prefix * \param prefix The prefix of the name * \return The returned name. */ Doc GetUniqueName(const std::string& prefix); Doc Print(Kind k); /*! * \brief Allocate name to a type variable. * \param var The input type variable. * \return The corresponding name. */ Doc AllocTypeVar(const TypeVar& var); /*! * \brief Allocate name to a variable. * \param var The input variable. * \return The corresponding name. */ Doc AllocVar(const Var& var); bool IsUnique(const Expr& expr); bool AlwaysInline(const Expr& expr); Doc PrintFunc(const Doc& prefix, const relay::Function& fn); Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func); Doc PrintMod(const IRModule& mod); //------------------------------------ // Overload of Expr printing functions //------------------------------------ Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true); // Should only be triggered when op is a free variable being visited for the // first time. Doc VisitExpr_(const VarNode* op) final; /*! * \brief special method to print out const scalar * \param dtype The data type * \param value The value to be printed. */ template <typename T> static Doc ScalarLiteral(DataType dtype, const T& value); Doc VisitExpr_(const ConstantNode* op) final; Doc VisitExpr_(const TupleNode* op) final; Doc VisitExpr_(const TupleGetItemNode* op) final; Doc VisitExpr_(const IfNode* op) final; Doc VisitExpr_(const LetNode* op) final; Doc VisitExpr_(const FunctionNode* op) final; Doc VisitExpr_(const GlobalVarNode* op) final; Doc VisitExpr_(const OpNode* op) final; Doc VisitExpr_(const CallNode* op) final; Doc VisitExpr_(const RefCreateNode* op) final; Doc VisitExpr_(const RefReadNode* op) final; Doc VisitExpr_(const RefWriteNode* op) final; Doc VisitExpr_(const MatchNode* op) final; Doc PrintPattern(const Pattern& pattern, bool meta); Doc VisitPattern_(const PatternConstructorNode* p) final; Doc VisitPattern_(const PatternTupleNode* pt) final; Doc VisitPattern_(const PatternWildcardNode* pw) final; Doc VisitPattern_(const PatternVarNode* pv) final; Doc VisitExpr_(const ConstructorNode* n) final; //------------------------------------ // Overload of Type printing functions //------------------------------------ Doc PrintType(const Type& type, bool meta); Doc VisitTypeDefault_(const Object* node) final; Doc VisitType_(const TypeVarNode* node) final; Doc VisitType_(const GlobalTypeVarNode* node) final; Doc VisitType_(const TypeCallNode* node) final; Doc PrintDType(DataType dtype); Doc VisitType_(const TensorTypeNode* node) final; Doc VisitType_(const TupleTypeNode* node) final; Doc VisitType_(const FuncTypeNode* node) final; Doc VisitType_(const RelayRefTypeNode* node) final; Doc VisitType_(const TypeDataNode* node) final; //------------------------------------ // Overload of Attr printing functions //------------------------------------ Doc VisitAttrDefault_(const Object* op) final; Doc VisitAttr_(const ArrayNode* op) final; Doc VisitAttr_(const tir::IntImmNode* op) final; Doc VisitAttr_(const tir::FloatImmNode* op) final; Doc VisitAttr_(const tir::StringImmNode* op) final; private: /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief additional comment function */ runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector<Doc> doc_stack_{}; /*! \brief Set for introduced vars */ std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_; /*! \brief Set for exprs have been printed optional information */ std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> opt_info_memo_; /*! \brief Map for result and memo_ diffs for visited expression */ std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_; /*! \brief Map from Expr to Doc */ std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_; /*! \brief Map from Type to Doc */ std::unordered_map<Type, Doc, ObjectPtrHash, ObjectPtrEqual> memo_type_; /*! \brief Map from Type to Doc */ std::unordered_map<Pattern, Doc, ObjectPtrHash, ObjectPtrEqual> memo_pattern_; /*! \brief name allocation map */ std::unordered_map<std::string, int> name_alloc_map_; /*! \brief meta data context */ TextMetaDataContext* meta_; /*! \brief counter of temporary variable */ size_t temp_var_counter_{0}; /*! \brief whether the printer is currently in an ADT definition */ bool in_adt_def_; /*! \brief arena for dependency graph */ support::Arena arena_; /*! \brief dependency graph of the expr */ DependencyGraph dg_; class AttrPrinter; friend class AttrPrinter; friend class tvm::TextPrinter; }; } // namespace relay } // namespace tvm namespace tvm { namespace tir { /*! * \brief Meta node collector * If we decide to put some node into meta, then all the sub-nodes inside * it need to be put in meta as well, since when parsing we need to know * whether two refs are the same */ class MetaCollector : public StmtExprVisitor { public: explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {} void Collect(const ObjectRef& n) { // these nodes can be print directly(StringLiteral or use identifier to identify) if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) { return; } if (n->IsInstance<StmtNode>()) { VisitStmt(Downcast<Stmt>(n)); } else if (n->IsInstance<PrimExprNode>()) { VisitExpr(Downcast<PrimExpr>(n)); } } void VisitStmt(const Stmt& n) override { meta_->GetMetaNode(n); StmtVisitor::VisitStmt(n); } void VisitExpr(const PrimExpr& n) override { meta_->GetMetaNode(n); ExprVisitor::VisitExpr(n); } private: TextMetaDataContext* meta_; }; class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>, public ExprFunctor<Doc(const PrimExpr&)>, public TypeFunctor<Doc(const Type&)> { public: explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} /*! \brief Print the node */ Doc Print(const ObjectRef& node); /*! \brief Place into `s` the name used in the preceding Print call for `v`. * \param v Var instance to check. Must point to a VarNode visited by Print. * \param s String to receive the name. * \return true when a name re-mapping was found. */ bool GetVarName(::tvm::tir::Var v, std::string* s); private: /*! \brief whether show meta data */ bool show_meta_; /*! \brief meta data context */ TextMetaDataContext* meta_; /*! \brief meta collector */ MetaCollector meta_collector_; /*! \brief Map from Var to Doc */ std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_; /*! \brief Map from Buffer to Doc */ std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_; /*! \brief Map from Buffer to Doc */ std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_; /*! \brief name allocation map */ std::unordered_map<std::string, int> name_alloc_map_; friend class tvm::TextPrinter; Doc VisitExpr_(const IntImmNode* op) override; Doc VisitExpr_(const FloatImmNode* op) override; Doc VisitExpr_(const StringImmNode* op) override; Doc VisitExpr_(const CastNode* op) override; Doc VisitExpr_(const VarNode* op) override; Doc VisitExpr_(const AddNode* op) override; Doc VisitExpr_(const SubNode* op) override; Doc VisitExpr_(const MulNode* op) override; Doc VisitExpr_(const DivNode* op) override; Doc VisitExpr_(const ModNode* op) override; Doc VisitExpr_(const FloorDivNode* op) override; Doc VisitExpr_(const FloorModNode* op) override; Doc VisitExpr_(const MinNode* op) override; Doc VisitExpr_(const MaxNode* op) override; Doc VisitExpr_(const EQNode* op) override; Doc VisitExpr_(const NENode* op) override; Doc VisitExpr_(const LTNode* op) override; Doc VisitExpr_(const LENode* op) override; Doc VisitExpr_(const GTNode* op) override; Doc VisitExpr_(const GENode* op) override; Doc VisitExpr_(const AndNode* op) override; Doc VisitExpr_(const OrNode* op) override; Doc VisitExpr_(const NotNode* op) override; Doc VisitExpr_(const SelectNode* op) override; Doc VisitExpr_(const BufferLoadNode* op) override; Doc VisitExpr_(const ProducerLoadNode* op) override; Doc VisitExpr_(const LoadNode* op) override; Doc VisitExpr_(const RampNode* op) override; Doc VisitExpr_(const BroadcastNode* op) override; Doc VisitExpr_(const LetNode* op) override; Doc VisitExpr_(const CallNode* op) override; Doc VisitExpr_(const ShuffleNode* op) override; Doc VisitExpr_(const ReduceNode* op) override; Doc VisitExprDefault_(const Object* op) override; Doc VisitStmt_(const LetStmtNode* op) override; Doc VisitStmt_(const AttrStmtNode* op) override; Doc VisitStmt_(const AssertStmtNode* op) override; Doc VisitStmt_(const StoreNode* op) override; Doc VisitStmt_(const BufferStoreNode* op) override; Doc VisitStmt_(const ProducerStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; Doc VisitStmt_(const ForNode* op) override; Doc VisitStmt_(const WhileNode* op) override; Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; Doc VisitStmtDefault_(const Object* op) override; Doc VisitType_(const PrimTypeNode* node) override; Doc VisitType_(const PointerTypeNode* node) override; Doc VisitType_(const TupleTypeNode* node) override; Doc PrintIRModule(const IRModule& module); Doc PrintPrimFunc(const PrimFunc& primFunc); Doc PrintArray(const ArrayNode* op); Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); Doc PrintBuffer(const BufferNode* op); Doc PrintProducer(const DataProducerNode* op); Doc BufferNode2Doc(const BufferNode* op, Doc doc); Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc); Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } Doc PrintBufferRegion(const BufferRegionNode* op); /*! * \brief special method to print out data type * \param dtype The data type */ static Doc PrintDType(DataType dtype); /*! * \brief special method to print out const scalar * \param dtype The data type * \param data The pointer to hold the data. */ template <typename T> static Doc PrintConstScalar(DataType dtype, const T& data); Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); Doc AllocProducer(const DataProducer& buffer); /*! * \brief special method to render vectors of docs with a separator * \param vec vector of docs * \param sep separator */ static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep); Doc PrintBody(const Stmt& body, bool indent = true); }; String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false); String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc<std::string(Stmt)> annotate); } // namespace tir } // namespace tvm namespace tvm { class TextPrinter { public: explicit TextPrinter(bool show_meta_data, const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate, bool show_warning = true) : show_meta_data_(show_meta_data), show_warning_(show_warning), annotate_(annotate), relay_text_printer_(show_meta_data, &meta_, annotate), tir_text_printer_(show_meta_data, &meta_) {} /*! \brief whether show meta data */ bool show_meta_data_; /*! \brief whether show the meta data warning message */ bool show_warning_; /*! \brief meta data context */ TextMetaDataContext meta_; /*! \brief additional comment function */ runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_; /*! \brief Relay Text Printer */ relay::RelayTextPrinter relay_text_printer_; /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } Doc PrintFinal(const ObjectRef& node) { Doc doc; if (node.defined() && node->IsInstance<IRModuleNode>()) { doc << PrintMod(Downcast<IRModule>(node)); } else if (node.defined() && (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() || node->IsInstance<tir::StmtNode>())) { doc << tir_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); } if (!meta_.empty()) { doc << Doc::NewLine(); if (show_meta_data_) { doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); } else if (show_warning_) { doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() << " * If you would like to see the full metadata section you can set the " << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine() << " */"; } } return doc; } Doc PrintMod(const IRModule& mod); }; } // namespace tvm #endif // TVM_PRINTER_TEXT_PRINTER_H_