/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file parser.cc * \brief A parser for TVM IR. */ #include #include #include #include #include #include #include #include #include #include #include #include "./meta_ref.h" #include "./op_table.h" #include "./span_check.h" #include "./tokenizer.h" namespace tvm { namespace parser { using namespace relay; using Expr = relay::Expr; /*! \brief The meta table maps from type key to a sequence of objects. */ using MetaTable = Map>; using tvm::transform::CreateModulePass; using tvm::transform::PassContext; /*! \brief A helper for passing around spans with data structures with * no span field. */ template struct Spanned { T data; Span span; Spanned() = default; Spanned(const Spanned& other) = default; Spanned(T data, Span span) : data(data), span(span) {} }; /*! \brief A wrapper structure for capturing the result of parsing * a global definition *before* we add it to the IRModule. * * This enables the parser to parse everything in one pass before * constructing the IRModule. 
*/ struct GlobalFunc { GlobalVar global; Function function; GlobalFunc() : global(), function() {} GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {} GlobalFunc(const GlobalFunc& gfunc) { this->global = gfunc.global; this->function = gfunc.function; } }; /*! \brief A wrapper structure for capturing all top-level definitions * when parsing a module. */ struct Definitions { /*! \brief The set of global functions. */ std::vector funcs; /*! \brief The set of type definitions. */ std::vector types; // TODO(@jroesch): contain meta-table below }; /*! \brief A structure representing the semantic versioning information * for a Relay program. */ class SemVer { public: int major_version; int minor_version; int patch_version; SemVer() : major_version(0), minor_version(0), patch_version(0) {} SemVer(int major_version, int minor_version, int patch_version) : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {} SemVer(const SemVer& other) : major_version(other.major_version), minor_version(other.minor_version), patch_version(other.patch_version) {} }; /*! \brief A simple wrapper around a mapping from raw string names * to a TVM variable, type variable or other binder type. */ template struct Scope { /*! \brief The internal map. */ std::unordered_map name_map; }; /*! \brief A stack of scopes. * * In order to properly handle scoping we must maintain a stack of scopes. * * A stack allows users to write programs which contain repeated variable * names and to properly handle both nested scopes and removal of variables * when they go out of scope. * * This is the classic approach to lexical scoping. */ template class ScopeStack { private: std::vector> scope_stack; std::unordered_map free_vars; public: /*! \brief Adds a variable binding to the current scope. 
*/ void Add(const std::string& name, const T& value) { if (!this->scope_stack.size()) { LOG(FATAL) << "internal issue"; } this->scope_stack.back().name_map.insert({name, value}); } void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); } /*! \brief Looks up a variable name in the scope stack returning the matching variable * in most recent scope. */ T Lookup(const std::string& name) { for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) { auto it = scope->name_map.find(name); if (it != scope->name_map.end()) { return it->second; } } // Check if we bound a free variable declaration. auto it = free_vars.find(name); if (it != free_vars.end()) { return it->second; } return T(); } /*! \brief Adds a fresh scope. */ void PushStack() { this->scope_stack.push_back(Scope()); } /*! \brief Removes the most recent scope. */ void PopStack() { this->scope_stack.pop_back(); } }; struct DuplicateKeyError : public Error { explicit DuplicateKeyError(const std::string& msg) : Error(msg) {} }; /*! \brief A table of interning strings as global function and type names. */ template struct InternTable { /*! \brief The internal table mapping strings to a unique allocation. */ std::unordered_map table; DiagnosticContext* ctx; /*! \brief Add the unique allocation. */ void Add(const std::string& name, const T& t) { auto it = table.find(name); if (it != table.end()) { throw DuplicateKeyError("duplicate key name in intern table"); } else { table.insert({name, t}); } } /*! \brief Return the unique allocation. 
*/ Optional Get(const std::string& name) const { auto it = table.find(name); if (it != table.end()) { return Optional(it->second); } else { return Optional(); } } }; GlobalVar AddOrGet(InternTable* table, const std::string& name) { auto var = table->Get(name); if (var) { return var.value(); } else { auto gvar = GlobalVar(name); table->Add(name, gvar); return gvar; } } GlobalTypeVar AddOrGet(InternTable* table, const std::string& name, TypeKind kind = TypeKind::kType) { auto var = table->Get(name); if (var) { auto tvar = var.value(); TypeKind& tvar_kind = const_cast(tvar->kind); tvar_kind = kind; return tvar; } else { auto gvar = GlobalTypeVar(name, kind); table->Add(name, gvar); return gvar; } } /*! \brief The parser class is the main interface to the parser. * the parser is not currently exposed beyond this .cc file. * * The parser is initialized with a diagnostic context, an * operator table, and a token stream. * * The rest of the internal state is used to map the human readable * form to in-memory IR representation. * * The main entry point to the parser are a set of parsing methods * such as `ParseModule` and `ParseExpr`. * * As with traditional recursive descent parsers the parsing methods * are factored recursively just as one would do with a formal language * grammar. * * You can view a recursive descent parser as a human friendly way to specify * a state machine, and thus this factoring is necessary as the 'state' of this * machine is the combination of the current parsing method and the next token. * * Parsing proceeds by matching a token and then dispatching to the appropriate * method to parse the next tokens in the stream. * * For example if we are parsing a type and encounter a "Tensor" token we switch * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`. * * Certain matches like this are unambiguous and proceed in a straight line fashion * once the initial token is found. 
Other parsing is more complex and requires some * tricks to correctly parse. * * For example when we find a '(' in an expression context, it may be part of * a tuple, the arguments to a call, or a parenthesized expression. The below code * disambiguate these cases by factoring expression parsing into a series of methods * which encode the parsing context and thus how to interpret the parenthesis. * * For more information one should be able to read the code in order starting with * `ParseModule` or `ParseExpr`. */ class Parser { public: /*! \brief The version that the parser is parsing. */ SemVer version; /*! \brief The IRModule we are building. */ IRModule module; /*! \brief The diagnostic context used for error reporting. */ DiagnosticContext diag_ctx; const Source& source; /*! \brief The current position in the token stream. */ int pos; /*! \brief The token stream for the parser. */ std::vector tokens; /*! \brief The configured operator table. */ OperatorTable op_table; /*! \brief Configure the whitespace mode, right now we ignore all whitespace. */ bool ignore_whitespace; /*! \brief A global mapping for GlobalVar. */ InternTable global_names; /*! \brief A global mapping for type definitions. */ InternTable type_names; /*! \brief A global mapping for constructor names. */ InternTable ctors; /*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */ std::unordered_map graph_ctx; /*! \brief The set of type scopes used for generics. */ ScopeStack type_scopes; /*! \brief The set of expression scopes used for lexical scope. */ ScopeStack expr_scopes; /*! \brief The metadata section. */ MetaTable meta_table; Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector tokens, OperatorTable op_table, MetaTable table) : module(module), diag_ctx(ctx), source(source), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true), meta_table(table) { InitializeGlobals(); InitializeTypeDefs(); } /*! 
If we are parsing into a module with previously loaded data types we need to * map constructor names and variable names in the global tables. */ void InitializeTypeDefs() { for (auto pair : this->module->type_definitions) { type_names.Add(pair.first->name_hint, pair.first); for (auto ctor : pair.second->constructors) { ctors.Add(ctor->name_hint, ctor); } } } void InitializeGlobals() { for (auto pair : this->module->functions) { global_names.Add(pair.first->name_hint, pair.first); } } /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. */ Token Peek() { // For now we ignore all whitespace tokens and comments. // We can tweak this behavior later to enable white space sensitivity in the parser. while (pos < static_cast(tokens.size()) && ignore_whitespace && (tokens.at(pos)->token_type == TokenType::kWhitespace || tokens.at(pos)->token_type == TokenType::kNewline || tokens.at(pos)->token_type == TokenType::kLineComment || tokens.at(pos)->token_type == TokenType::kComment)) { pos++; } if (pos < static_cast(tokens.size())) { return Token(this->tokens.at(pos)); } else { return Token::Null(); } } /*! \brief Lookahead by N tokens. * \param n The number of tokens to lookahead. * \return The Nth token. */ Token Lookahead(int n) { ICHECK_GE(n, 1) << "lookahead is only valid when n >= 1"; // We intend to skip n - 1 tokens, then return the nth. auto old_pos = pos; for (int i = 0; i < n - 1; i++) { Peek(); pos++; } auto tok = Peek(); pos = old_pos; return tok; } /*! \brief Consume a token, this method is the lowest level way to consume a token * and will not ignore white space or look ahead in anyway. * * /param token_type The token type to match. 
*/ void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span) << "expected a " << Pretty(token_type) << " found " << Pretty(Peek()->token_type)); } pos++; } /*! Match a token in the stream, this will first invoke Peek, ignoring tokens such * as whitespace or comments returning the first meaningful token. * * We then try and consume the requested token, this will trigger an error if the * current token does not match the token_type. */ Token Match(const TokenType& token_type) { auto tok = Peek(); Consume(token_type); return tok; } /*! Conditionally consume a token when it matches, this will never trigger an error * as we guard against consuming the token before we do. * * Useful for matching optional tokens, effectively looksahead by one. */ bool WhenMatch(const TokenType& token_type) { VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek(); if (Peek()->token_type == token_type) { Consume(token_type); return true; } else { return false; } } /* \brief Add a graph binding to the parsing context * * For example if we parse %0 = add(...), map 0 -> add(...), etc. */ void AddGraphBinding(const Token& token, const Expr& expr) { auto graph_no = token.ToNumber(); this->graph_ctx.insert({graph_no, expr}); } /* \brief Lookup a previously bound graph variable. * * Note: we take tokens in all lookup methods so that we * that we can do error reporting based on token location. */ Expr LookupGraphBinding(const Token& token) { auto graph_no = token.ToNumber(); return this->graph_ctx.at(graph_no); } /*! \brief Bind a local variable in the expression scope. * * "x" -> Var("x"), these are needed to map from the raw string names * to unique variable nodes. */ Var BindVar(const std::string& name, const relay::Type& type_annotation) { auto var = Var(name, type_annotation); this->expr_scopes.Add(name, var); return var; } /*! \brief Bind a local variable in the expression scope. 
* * "x" -> Var("x"), these are needed to map from the raw string names * to unique variable nodes. */ Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { auto var = Var(name, type_annotation); this->expr_scopes.AddFreeVar(name, var); return var; } /*! \brief Bind a type variable in the type scope. * * "A" -> TypeVar("A", ...), these are needed to map from raw string names * to unique type variable nodes. */ TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) { auto type_var = TypeVar(name, type_kind); this->type_scopes.Add(name, type_var); return type_var; } /*! \brief Lookup a variable in the expression scope. * * Note: all lookup methods take tokens intentionally for error reporting information. */ Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { diag_ctx.Emit(Diagnostic::Error(local->span) << "this local variable has not been previously declared"); } return var; } /*! \brief Lookup a variable in the type scope. * * Note: all lookup methods take tokens intentionally for error reporting information. */ TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); return var; } /*! \brief Add an expression scope to the scope stack. */ void PushScope() { this->expr_scopes.PushStack(); } /*! \brief Remove N expression scopes from the scope stack. */ void PopScopes(int n) { for (int i = 0; i < n; i++) { this->expr_scopes.PopStack(); } } /*! \brief Add an type scope to the scope stack. */ void PushTypeScope() { this->type_scopes.PushStack(); } /*! \brief Remove N type scopes from the scope stack. */ void PopTypeScopes(int n) { for (int i = 0; i < n; i++) { this->type_scopes.PopStack(); } } /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. 
*/ NDArray NumberToNDArray(const Token& token) { if (token->token_type == TokenType::kInteger) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; int64_t i = Downcast(token->data); if (i > std::numeric_limits::max()) { auto dtype = String2DLDataType("int64"); auto data = NDArray::Empty({}, dtype, dev); auto array = reinterpret_cast(data->data); // revisit this, literal node issue. array[0] = i; return data; } else { auto dtype = String2DLDataType("int32"); auto data = NDArray::Empty({}, dtype, dev); auto array = reinterpret_cast(data->data); // revisit this, literal node issue. array[0] = i; return data; } } else if (token->token_type == TokenType::kFloat) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto float_imm = Downcast(token->data); auto data = NDArray::Empty({}, float_imm->dtype, dev); auto array = reinterpret_cast(data->data); // revisit this, literal node issue. // TODO(@jroesch): bounds checking float value = float_imm->value; array[0] = value; return data; } else { LOG(FATAL) << "internal error: should only call this function on numeric tokens"; return NDArray(); } } /*! \brief Convert a boolean value to an NDArray for embedding into the Relay program. */ NDArray BooleanToNDarray(bool value) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("bool"); auto data = NDArray::Empty({}, dtype, dev); auto array = reinterpret_cast(data->data); array[0] = value; return data; } [[noreturn]] void ParseError(const Token& token, const std::string& msg) { throw std::runtime_error(msg); } /*! \brief A parsing helper for a bracketed expression . */ template R Bracket(TokenType open, TokenType close, std::function parser) { Match(open); R result = parser(); Match(close); return result; } /*! \brief Parse `(` parser() `)`. */ template R Parens(std::function parser) { return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); } /*! \brief Parse `{` parser() `}`. 
*/ template R Block(std::function parser) { return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } template R WithSpan(std::function parser) { auto start_span = Peek()->span; VLOG(9) << "WithSpan: start_span = " << start_span; R ast = parser(); if (ast.defined()) { // The token at the head of the stream is now 1 past where we parsed. So we find its start // position as its start and end, so that when we merge we only grow the spanned region // to the start of the current stream. auto span_pos = pos - 1; while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace || tokens.at(span_pos)->token_type == TokenType::kNewline || tokens.at(span_pos)->token_type == TokenType::kLineComment || tokens.at(span_pos)->token_type == TokenType::kComment)) { span_pos--; } auto end_token = tokens.at(span_pos); VLOG(9) << "WithSpan: end_span = " << end_token->span; ast->span = start_span.Merge(end_token->span); } return ast; } struct MetaRef { std::string type_key; uint64_t node_index; Span span; MetaRef(std::string type_key, uint64_t node_index, Span span) : type_key(type_key), node_index(node_index), span(span) {} }; MetaRef MetaRefFromToken(const Token& tok) { Call ref = Downcast(tok->data); auto attrs = ref->attrs.as(); auto type_key = attrs->node_type_key; auto index = attrs->node_index; return MetaRef(type_key, index, ref->span); } /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`. * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]` * the second, and so on. 
*/ ObjectRef ParseMetaRef() { auto meta_ref_tok = Match(TokenType::kMetaReference); auto meta_ref = MetaRefFromToken(meta_ref_tok); auto it = this->meta_table.find(meta_ref.type_key); if (it != this->meta_table.end()) { auto nodes = (*it).second; if (meta_ref.node_index < nodes.size()) { return nodes[meta_ref.node_index]; } else { this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) << "the node index `" << meta_ref.node_index << "` is out of bounds for `" << meta_ref.type_key << "`"); return ObjectRef(); } } else { this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) << "no entry in the meta table for `" << meta_ref.type_key << "`"); return ObjectRef(); } } /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and * ending with a stop token. * * The simple form being ( )* . * * This also provides a fourth argument which is allowed to run when the sequence which matches * the inner sequence can not proceed. * * This is useful for parsing things like attributes which don't match the standard expression * parsers but are contained within the stop token. */ template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) << " stop=" << ToString(stop); Match(start); // This is for the empty arguments list case, if we have token stream // we must parse leftovers, then match a stop token. if (before_stop) { auto did_parse = before_stop(); if (did_parse) { Match(stop); return {}; } } // This is the case in which we find an empty arguments lists and no leftovers. 
if (WhenMatch(stop)) { return Array(); } else { VLOG(9) << "Parser::ParseSequence: parse first"; auto data = parse(); Array elements = {data}; if (WhenMatch(stop)) { return elements; // parse '( expr ',' * ')' } else if (WhenMatch(sep)) { while (true) { VLOG(9) << "Parser::ParseSequence: parse element"; if (WhenMatch(stop)) { break; } else { // If before stop is if (before_stop) { auto did_parse = before_stop(); if (did_parse) { Match(stop); return elements; } } auto data = parse(); WhenMatch(sep); elements.push_back(data); } } return elements; } else { auto next = Peek(); this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type)); return Array(nullptr); } } } /*! \brief Parse a full IRModule. */ IRModule ParseModule() { // Parse the semver header at the top of the module. this->version = ParseSemVer(); // Parse the definitions. auto defs = ParseDefinitions(); // Parse the metadata section at the end. auto metadata = ParseMetadata(); Match(TokenType::kEndOfFile); for (auto type_def : defs.types) { module->AddTypeDef(type_def->header, type_def); } for (auto func : defs.funcs) { module->Add(func.global, func.function, true); } return module; } /*! \brief Parse the semantic versioning header. */ SemVer ParseSemVer(bool required = true) { if (Peek()->token_type == TokenType::kVersion) { auto version = Match(TokenType::kVersion); // TODO(@jroesch): we currently only support 0.0.5. if (version.ToString() != "\"0.0.5\"") { this->diag_ctx.Emit(Diagnostic::Error(version->span) << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { this->diag_ctx.Emit(Diagnostic::Error(Peek()->span) << "expected text format semantic version, found a " << PrettyPrint(Peek())); this->diag_ctx.Emit(Diagnostic::Help(Peek()->span) << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); } /*! \brief Parse zero or more Relay definitions. 
*/ Definitions ParseDefinitions() { Definitions defs; while (true) { auto next = Peek(); switch (next->token_type) { case TokenType::kDefn: { Consume(TokenType::kDefn); auto global_tok = Match(TokenType::kGlobal); auto global_name = global_tok.ToString(); auto global = AddOrGet(&global_names, global_name); auto func = WithSpan([&]() { return ParseFunctionDef(); }); ICHECK(func->span.defined()) << "spans must be set in parser"; defs.funcs.push_back(GlobalFunc(global, func)); continue; } case TokenType::kTypeDef: { defs.types.push_back(ParseTypeDef()); continue; } case TokenType::kExtern: { Consume(TokenType::kExtern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { diag_ctx.Emit(Diagnostic::Error(next->span) << "an external type may not have any constructors"); } defs.types.push_back(type_def); } default: return defs; } } } /*! \brief Parse zero or more Relay type definitions. */ TypeData ParseTypeDef() { // Match the `type` keyword. Match(TokenType::kTypeDef); // Parse the type's identifier. auto type_tok = Match(TokenType::kIdentifier); auto type_id = type_tok.ToString(); auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle); Array generics; bool should_pop = false; if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. PushTypeScope(); should_pop = true; generics = ParseSequence( TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } Array ctors; if (Peek()->token_type == TokenType::kLCurly) { // Parse the list of constructors. ctors = ParseSequence( TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { // First match the name of the constructor. auto ctor_tok = Match(TokenType::kIdentifier); auto ctor_name = ctor_tok.ToString(); Constructor ctor; // Match the optional field list. 
// (Continuation of ADT/type-definition parsing begun above: builds one
// constructor for the type being defined. A bare name yields a nullary
// constructor; a parenthesized list of types yields an n-ary one.)
if (Peek()->token_type != TokenType::kOpenParen) {
  ctor = tvm::Constructor(ctor_name, {}, type_global);
} else {
  auto arg_types = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
                                 [&]() { return ParseType(); });
  ctor = tvm::Constructor(ctor_name, arg_types, type_global);
}
ICHECK(ctor.defined());
try {
  // Register the constructor by name; Add throws on duplicates.
  this->ctors.Add(ctor_name, ctor);
} catch (const DuplicateKeyError& e) {
  this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span)
                           << "a constructor with the name "
                           << "`" << ctor_name << "` "
                           << "was previously defined");
}
return ctor;
});
}
// Now pop the type scope.
if (should_pop) {
  PopTypeScopes(1);
}
return TypeData(type_global, generics, ctors);
}

/*! \brief Concatenate the printed token-type names of the next `n` tokens
 * into a single lookup key for the operator table.
 *
 * `n` is clamped to the number of tokens remaining in the stream.
 */
std::string HackTokensAsString(int n) {
  std::stringstream key;
  n = std::min(static_cast(tokens.size() - pos), n);
  for (int i = 0; i < n; i++) {
    key << ToString(tokens.at(pos + i)->token_type);
  }
  return key.str();
}

/*! \brief Try to parse an operator at the current position.
 *
 * Tries the longest candidate first (4 tokens down to 1); on a hit the
 * cursor is advanced past the matched tokens and the rule is recorded.
 * NOTE(review): the loop keeps running after a match, so shorter keys may
 * match again at the new position and append further rules; callers
 * (ParseExprBinOp) only consume matched[0] — confirm this is intended.
 */
std::vector ParseOp() {
  std::vector matched;
  Peek();
  for (int i = 4; i > 0; i--) {
    auto key = HackTokensAsString(i);
    auto it = this->op_table.this_is_a_hack.find(key);
    if (it != this->op_table.this_is_a_hack.end()) {
      pos = pos + i;
      matched.push_back(it->second);
    }
  }
  return matched;
}

/*! \brief Parse a single Relay expression.
 *
 * Parses a `;`-separated sequence of sub-expressions; a sequence of more
 * than one is desugared into a chain of `Let` bindings with fresh anonymous
 * variables (see the tail of this function).
 */
Expr ParseExpr() {
  VLOG(9) << "Parser::ParseExpr";
  return WithSpan([this] {
    std::vector exprs;
    while (true) {
      VLOG(9) << "Parser::ParseExpr: parsing a single expression";
      auto next = Peek();
      switch (next->token_type) {
        // For graph or let, match first rhs, then invoke ParseBindingExpr
        // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue
        case TokenType::kLCurly: {
          // NB: Might need to optimize to remove deep recursion.
          // Stack should only grow proportionally to the number of
          // nested scopes.
          // Parses `{` expression `}`; the braces introduce a fresh
          // variable scope.
          auto block = WithSpan([&]() {
            return Bracket(TokenType::kLCurly, TokenType::kRCurly, [&]() {
              PushScope();
              auto expr = ParseExpr();
              PopScopes(1);
              return expr;
            });
          });
          exprs.push_back(block);
          break;
        }
        case TokenType::kFreeVar: {
          // `free_var %x : T;` introduces a free variable with an optional
          // type annotation (IncompleteType when omitted).
          Consume(TokenType::kFreeVar);
          auto var_token = Match(TokenType::kLocal);
          Type type;
          if (WhenMatch(TokenType::kColon)) {
            type = ParseType();
          } else {
            type = IncompleteType();
          }
          BindFreeVar(var_token.ToString(), type);
          break;
        }
        // Parses `let ...`;
        case TokenType::kLet:
          exprs.push_back(ParseBindingExpr());
          break;
        case TokenType::kMatch:
        case TokenType::kPartialMatch: {
          bool is_total = next->token_type == TokenType::kMatch;
          Consume(next->token_type);
          exprs.push_back(ParseMatch(is_total));
          break;
        }
        // %x ...
        case TokenType::kGraph:
          // Only a graph *binding* (`%x = ...`) goes to ParseBindingExpr;
          // a bare graph reference falls through to the operator parser.
          if (Lookahead(2)->token_type == TokenType::kEqual) {
            exprs.push_back(ParseBindingExpr());
            break;
          }
          // intentional fall through here.
        default: {
          exprs.push_back(ParseExprBinOp());
          break;
        }
      }
      // A trailing `;` continues the sequence; anything else ends it.
      if (!WhenMatch(TokenType::kSemicolon)) {
        break;
      }
    }
    ICHECK_GE(exprs.size(), 1);
    if (exprs.size() == 1) {
      // ICHECK(exprs[0].defined() && exprs[0]->span.defined())
      //   << "parser must set expression spans.\n"
      //   << exprs[0];
      return exprs[0];
    } else {
      // Fold the sequence right-to-left into nested Lets binding unnamed
      // (ignored) variables, so `e1; e2; e3` evaluates in order.
      auto body = exprs.back();
      exprs.pop_back();
      while (exprs.size()) {
        auto value = exprs.back();
        ICHECK(value->span.defined()) << "parser must set expression spans.";
        exprs.pop_back();
        body = relay::Let(Var("", IncompleteType()), value, body, value->span.Merge(body->span));
      }
      ICHECK(body->span.defined()) << "parser must set expression spans.";
      return body;
    }
  });
}

/*! \brief Parse a "binding expression"; an expression where
 * a graph or let variable is bound.
 *
 * In order to avoid stack overflow this is implemented in a special
 * iterative way to keep stack depth constant in a long chain of bindings.
 */
Expr ParseBindingExpr() {
  // We use a loop here so that the stack depth
  // does not grow linearly with a sequence of
  // graph or let bindings.
  //
  // Assuming we start at call depth k, we will
  // enter k + c call frames to parse the RHS
  // of the bindings where `c` is the depth
  // of recursion needed by RHS.
  //
  // If RHS is a call expression then c=1.
  //
  // Once we have parsed the RHS we will be
  // back at depth K, and will return to
  // this loop header to parse another
  // graph or let binding.
  //
  // This ensures for n sequential bindings
  // the call depth will be the same before
  // and after parsing the n bindings.
  VLOG(9) << "Parser::ParseBindingExpr";
  std::vector> bindings;
  int scopes = 0;
  while (true) {
    auto next = Peek();
    if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) {
      // Graph binding: `%x = rhs;` — recorded in the graph-binding table,
      // not as a Let.
      Match(TokenType::kGraph);
      Match(TokenType::kEqual);
      auto val = this->ParseExprBinOp();
      Match(TokenType::kSemicolon);
      AddGraphBinding(next, val);
    } else if (next->token_type == TokenType::kLet) {
      auto span = next->span;
      // Parse the 'let'.
      Consume(TokenType::kLet);
      // Parse the local '%'.
      auto local_tok = Match(TokenType::kLocal);
      auto string = local_tok.ToString();
      // Parse the optional type annotation (':' ).
      Type type;
      if (WhenMatch(TokenType::kColon)) {
        type = ParseType();
      }
      auto var = BindVar(string, type);
      // Parse the '='.
      Match(TokenType::kEqual);
      // Parse the body, and the ';'.
      auto val = this->ParseExprBinOp();
      Consume(TokenType::kSemicolon);
      // Add the bindings to the local data structure.
      std::tuple tuple(var, val, span);
      bindings.push_back(tuple);
      scopes++;
      PushScope();
    } else {
      // This is the only case we will increase the stack
      // depth.
      //
      // If we parse a program which is a sequence of N bindings
      // followed by a single body expression we will end up with
      // a call depth of 3, the first call to ParseExpr, then
      // ParseBindingExpr, then finally ParseExpr once more.
      auto body = this->ParseExpr();
      // Remove the same number of scopes we added.
      PopScopes(scopes);
      if (bindings.size() == 0) {
        return body;
      } else {
        // We can now build the let binding up backwards.
        for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) {
          auto span = body->span.Merge(std::get<2>(*binding));
          body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span);
        }
        return body;
      }
    }
  }
}

/*! Parse a function definition without a leading keyword or identifier.
 *
 * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
 */
Function ParseFunctionDef() {
  VLOG(9) << "Parser::ParseFunctionDef";
  return WithSpan([&]() {
    // A function introduces both a value scope (params) and a type scope
    // (generics); both are popped after the body is parsed.
    PushScope();
    PushTypeScope();
    Array generics;
    if (Peek()->token_type == TokenType::kLSquare) {
      // If we have generics we will parse them and add them to the scope.
      generics = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
        auto type_var_name = Match(TokenType::kIdentifier).ToString();
        return BindTypeVar(type_var_name, TypeKind::kType);
      });
    }
    Map raw_attrs;
    // Params: `%name [: Type]`, comma separated; the second callback lets
    // a trailing `ident = value` attribute list terminate the sequence.
    auto params = ParseSequence(
        TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
        [&]() {
          auto token = Match(TokenType::kLocal);
          auto string = token.ToString();
          Type type;
          if (WhenMatch(TokenType::kColon)) {
            type = ParseType();
          }
          return BindVar(string, type);
        },
        [&] {
          auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
          auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
          if (is_ident && next_is_equal) {
            raw_attrs = ParseAttrs();
            return true;
          }
          return false;
        });
    Type ret_type;
    // `->` is tokenized as kMinus followed by kRAngle.
    if (WhenMatch(TokenType::kMinus)) {
      Match(TokenType::kRAngle);
      ret_type = ParseType();
    }
    auto body = Block([&]() { return ParseExpr(); });
    PopTypeScopes(1);
    PopScopes(1);
    // TODO(@jroesch): attributes should never be null, they should always be empty.
    if (raw_attrs.size()) {
      return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
    } else {
      return relay::Function(params, body, ret_type, generics, tvm::DictAttrs());
    }
  });
}

/*! \brief Parse an if-expression.
*/
Expr ParseIf() {
  return WithSpan([&]() {
    VLOG(9) << "Parser::ParseIf";
    Consume(TokenType::kIf);
    // `if (guard) { true_branch } else { false_branch }`; each branch body
    // gets its own variable scope.
    auto guard = WithSpan([&] { return Parens([&] { return ParseExpr(); }); });
    auto true_branch = Block([&] {
      this->PushScope();
      auto expr = ParseExpr();
      this->PopScopes(1);
      return expr;
    });
    Match(TokenType::kElse);
    auto false_branch = Block([&] {
      this->PushScope();
      auto expr = ParseExpr();
      this->PopScopes(1);
      return expr;
    });
    return relay::If(guard, true_branch, false_branch);
  });
}

/* This factors parsing a list of patterns for both tuples, and constructors. */
Array ParsePatternList() {
  return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
                       [&] { return ParsePattern(); });
}

/*! \brief Parses a pattern for a match expression.
 *
 * A pattern is either a wildcard `_`, a local `%name`,
 * a constructor `C(p1, ..., pn)` or tuple `(p1, ..., pn)`.
 *
 * This function recursively parses a pattern.
 */
Pattern ParsePattern() {
  VLOG(9) << "Parser::ParsePattern";
  auto next = Peek();
  switch (next->token_type) {
    case TokenType::kUnderscore: {
      Match(TokenType::kUnderscore);
      return PatternWildcard();
    }
    case TokenType::kLocal: {
      // `%name [: Type]` binds a fresh pattern variable.
      auto id = Match(TokenType::kLocal);
      Type type_annotation;
      if (WhenMatch(TokenType::kColon)) {
        type_annotation = ParseType();
      }
      auto var = BindVar(id.ToString(), type_annotation);
      return PatternVar(var);
    }
    case TokenType::kIdentifier: {
      // A bare identifier must name a previously-registered constructor.
      auto id = Match(TokenType::kIdentifier);
      auto ctor = ctors.Get(id.ToString());
      if (!ctor) {
        diag_ctx.EmitFatal(
            // TODO(@jroesch): split into error and help
            // deal with multiple rendering
            Diagnostic::Error(id->span)
            << "undefined constructor name `" << id.ToString()
            << "`, perhaps you intended to write a"
            << "pattern variable, considering changing this to `%" << id.ToString() << "`");
      }
      if (Peek()->token_type == TokenType::kOpenParen) {
        auto fields = ParsePatternList();
        return PatternConstructor(ctor.value(), fields);
      } else {
        return PatternConstructor(ctor.value(), {});
      }
    }
    default:
      // Anything else must open a tuple pattern `(p1, ..., pn)`.
      return PatternTuple(ParsePatternList());
  }
}

/*! \brief Parse one arm of a match expression: `pattern => expr`.
 * Pattern variables are scoped to the arm's body. */
Clause ParseMatchArm() {
  PushScope();
  auto pattern = ParsePattern();
  // `=>` is tokenized as kEqual followed by kRAngle.
  Match(TokenType::kEqual);
  Consume(TokenType::kRAngle);
  auto expr = ParseExpr();
  PopScopes(1);
  return Clause(pattern, expr);
}

/*! \brief Parse `match scrutinee { arm, ... }`.
 * \param is_total Whether the match is total (`match`) or partial (`match?`).
 */
Expr ParseMatch(bool is_total) {
  return WithSpan([&]() {
    Expr scrutinee = ParseAtomicExpr();
    Array clauses = ParseSequence(
        TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&] { return ParseMatchArm(); });
    return relay::Match(scrutinee, clauses, is_total);
  });
}

/*! \brief Parse a (possibly empty) chain of binary operators.
 *
 * Classic operator-precedence parsing with an explicit operator stack and
 * expression stack, reducing according to each rule's precedence and
 * associativity.
 */
Expr ParseExprBinOp() {
  VLOG(9) << "Parser::ParseExprBinOp";
  return WithSpan([this] {
    // We must parse at least one expression, the default
    // case is that there is no operator and we will fall
    // through.
    std::vector exprs;
    Expr expr = WithSpan([this] { return ParseCallExpr(); });
    exprs.push_back(expr);
    // Now we parse an optional op.
    std::vector ops;
    // We will now parse 0 or more operator occurrences.
    while (true) {
      auto opt_op = ParseOp();
      // If we didn't parse one we are done.
      if (opt_op.size() == 0) {
        break;
      }
      // Read the operation we parsed.
      auto op = opt_op[0];
      Expr right = WithSpan([this] { return ParseCallExpr(); });
      ICHECK(right->span.defined());
      // If the operator stack is empty
      // we parse an operator and expression
      // and push them to stacks, then
      // continue.
      if (ops.size() == 0) {
        ops.push_back(op);
        exprs.push_back(right);
        continue;
      }
      // Higher precedence (or equal, right-associative): shift.
      if (op.precedence > ops.back().precedence ||
          (op.precedence == ops.back().precedence && op.left_assoc == false)) {
        ops.push_back(op);
        exprs.push_back(right);
        continue;
      }
      // Otherwise reduce pending higher/equal-precedence operators into
      // calls before shifting the new one.
      while (ops.size() && (op.precedence < ops.back().precedence ||
                            (op.precedence == ops.back().precedence && op.left_assoc == true))) {
        Rule new_op = ops.back();
        ops.pop_back();
        Expr right = exprs.back();
        exprs.pop_back();
        Expr left = exprs.back();
        exprs.pop_back();
        ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
        exprs.push_back(
            relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
      }
      exprs.push_back(right);
      ops.push_back(op);
    }
    // Reduce whatever remains on the stacks.
    while (ops.size()) {
      Rule new_op = ops.back();
      ops.pop_back();
      Expr right = exprs.back();
      exprs.pop_back();
      Expr left = exprs.back();
      exprs.pop_back();
      ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
      exprs.push_back(
          relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
    }
    ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack.";
    ICHECK_EQ(exprs.size(), 1)
        << "Only a single expression should be left on the expression stack.";
    return exprs[0];
  });
}

/*! \brief Parse a single attribute value: a literal, meta reference,
 * list, bracketed value, `nullptr`/`None`, or a general expression. */
ObjectRef ParseAttributeValue() {
  VLOG(9) << "Parser::ParseAttributeValue";
  auto next = Peek();
  switch (next->token_type) {
    case TokenType::kFloat:
    case TokenType::kInteger:
    case TokenType::kBoolean:
    case TokenType::kStringLiteral:
      return Match(next->token_type)->data;
    case TokenType::kMetaReference:
      return ParseMetaRef();
    case TokenType::kLSquare: {
      // Nested list of attribute values.
      return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
                           [&]() { return ParseAttributeValue(); });
    }
    case TokenType::kOpenParen: {
      // TODO(@jroesch): need to figure out bracket vs. sequence)
      // return ParseSequence(TokenType::kOpenParen, TokenType::kComma,
      // TokenType::kCloseParen,
      // [&]() { return ParseAttributeValue(); });
      return Bracket(TokenType::kOpenParen, TokenType::kCloseParen,
                     [&]() { return ParseAttributeValue(); });
    }
    // TODO(@jroesch): not sure about this being the right way to handle nulls.
    case TokenType::kIdentifier: {
      if (auto text = next->data.as()) {
        std::string id = GetRef(text);
        if (id == "nullptr") {
          Match(TokenType::kIdentifier);
          return ObjectRef();
        }
        if (id == "None") {
          Match(TokenType::kIdentifier);
          return Optional();
        }
      }
      // Not a null literal: fall through and parse as an expression.
    }
    default:
      return ParseAtomicExpr();
  }
}

/*! \brief Parse a comma-separated list of `key = value` attribute pairs.
 * Stops at the first token that is not an identifier. */
Map ParseAttrs() {
  VLOG(9) << "Parser::ParseAttrs";
  Map kwargs;
  while (Peek()->token_type == TokenType::kIdentifier) {
    auto key = GetHierarchicalName(ParseHierarchicalName().data);
    Match(TokenType::kEqual);
    // TODO(@jroesch): syntactically what do we allow to appear in attribute right hand side.
    auto value = ParseAttributeValue();
    // TODO(@jroesch): we need a robust way to handle this; writing dtypes as strings in text
    // format is bad.
    kwargs.Set(key, value);
    WhenMatch(TokenType::kComma);
  }
  VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs;
  return kwargs;
}

/*! \brief Parse a call argument list `(arg, ..., key = value, ...)` applied
 * to `op`, materializing the call's Attrs.
 *
 * Attrs come from either an inline `key = value` list (reflected into the
 * op's registered attrs type) or a trailing meta reference. Returns an
 * undefined Expr when the next token does not open an argument list.
 */
Expr ParseCallArgs(Expr op) {
  ICHECK(op.defined()) << "the operator must be defined";
  VLOG(9) << "Parser::ParseCallArgs";
  Attrs attrs;
  std::string op_key;
  bool is_op = false;
  if (auto op_node = op.as()) {
    is_op = true;
    op_key = op_node->attrs_type_key;
  }
  if (Peek()->token_type == TokenType::kOpenParen) {
    Array args = ParseSequence(
        TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
        [&] { return ParseExpr(); },
        [&] {
          // Terminator callback: detect either pretty-printed attrs
          // (`ident = ...`) or a final meta-reference attrs object.
          auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
          auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
          auto is_pretty_attrs = is_ident && next_is_equal;
          auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference;
          // TODO(@jroesch): might not handle trailing comma
          auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
          auto is_meta_attrs = is_meta_next && last_meta;
          if (is_pretty_attrs || is_meta_attrs) {
            if (is_meta_attrs) {
              auto meta_ref = ParseMetaRef();
              if (meta_ref.as()) {
                attrs = Downcast(meta_ref);
              } else {
                // Not awesome parsing code here: the meta reference was a
                // plain argument, so back the cursor up and keep parsing args.
                this->pos--;
                return false;
              }
            } else {
              auto raw_attrs = ParseAttrs();
              if (is_op && op_key.size()) {
                // Known op: reflect raw attrs into its registered attrs type.
                auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
                ICHECK(attr_obj.defined());
                attrs = Downcast(attr_obj);
              } else if (raw_attrs.count("attrs_type_key")) {
                // Otherwise the attrs must name their own type key.
                String attr_key = Downcast(raw_attrs["attrs_type_key"]);
                if (attr_key.size()) {
                  raw_attrs.erase("attrs_type_key");
                  auto attr_obj =
                      tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs);
                  ICHECK(attr_obj.defined());
                  attrs = Downcast(attr_obj);
                }
              } else {
                this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
                                         << "unable to determine the 'attrs_type_key' with which "
                                            "to represent the call attributes for this operator");
              }
            }
            return true;
          }
          return false;
        });
    if (!attrs.defined()) {
      // No inline attrs: ops still get a default-constructed attrs object.
      if (is_op && op_key.size()) {
        auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {});
        ICHECK(attr_obj.defined());
        attrs = Downcast(attr_obj);
      }
    }
    // TODO(@jroesch): in a secondary pass adjust spans.
    return Expr(Call(op, args, attrs, {}));
  } else {
    return Expr();
  }
  return Expr();
}

/*! \brief Parse an atomic expression followed by zero or more call
 * argument lists (supports curried application `f(x)(y)`). */
Expr ParseCallExpr() {
  VLOG(9) << "Parser::ParseCallExpr";
  return WithSpan([this] {
    Expr expr = ParseAtomicExpr();
    // Parse as many call args as possible, building up expression
    //
    // NB(@jroesch): this seems like a hack but in order to parse curried functions
    // and avoid complex grammar we will parse multiple call lists in a row.
    while (Peek()->token_type == TokenType::kOpenParen) {
      auto new_expr = ParseCallArgs(expr);
      if (new_expr.defined()) {
        expr = new_expr;
      } else {
        break;
      }
    }
    // We need a zero-arity case for constructors.
    if (auto ctor_node = expr.as()) {
      if (ctor_node->inputs.size() == 0) {
        return Expr(Call(expr, {}));
      }
    }
    return expr;
  });
}

/*! \brief Look up a registered operator by name, reporting a fatal
 * diagnostic at `span` when it does not exist. */
Expr GetOp(const std::string& op_name, const Span& span) {
  VLOG(9) << "op_name=" << op_name << " span=" << span;
  try {
    return Op::Get(op_name);
  } catch (const Error& e) {
    // we can relax this, but probably need to relax checks or return non-null here.
    this->diag_ctx.EmitFatal(Diagnostic::Error(span)
                             << "operator `" << op_name
                             << "` not found, perhaps you forgot to register it?");
    return Expr();
  }
}

/*! \brief Parse an atomic (non-operator) expression: literals, locals,
 * globals, identifiers/ops, graph refs, meta refs, functions, `if`,
 * references, tuples — followed by an optional `.N` tuple projection. */
Expr ParseAtomicExpr() {
  VLOG(9) << "Parser::ParseAtomicExpr";
  Expr expr = WithSpan([this] {
    auto next = Peek();
    switch (next->token_type) {
      case TokenType::kInteger:
      case TokenType::kFloat: {
        // Numeric literals become scalar Constant nodes.
        Consume(next->token_type);
        auto number = NumberToNDArray(next);
        Expr e = Constant(number, next->span);
        ICHECK(e->span.defined()) << "constant spans must be defined";
        return e;
      }
      case TokenType::kBoolean: {
        Consume(TokenType::kBoolean);
        int64_t value = Downcast(next->data);
        auto boolean = BooleanToNDarray(value);
        Expr e = Constant(boolean, next->span);
        ICHECK(e->span.defined()) << "constant spans must be defined";
        return e;
      }
      // Parse a local of the form `%x`.
      case TokenType::kLocal: {
        Consume(TokenType::kLocal);
        return Expr(LookupLocal(next));
      }
      // Parse a global of the form `@x`.
      case TokenType::kGlobal: {
        auto global_name = next.ToString();
        Consume(TokenType::kGlobal);
        auto global = AddOrGet(&global_names, global_name);
        return Expr(global);
      }
      // Parse a name of the form `x`.
      // Right now we fail to parse `x.y`.
      case TokenType::kIdentifier: {
        // A registered constructor name wins; otherwise treat the
        // (possibly dotted) name as an operator reference.
        auto ctor = ctors.Get(next.ToString());
        if (ctor) {
          Consume(TokenType::kIdentifier);
          return Expr(ctor.value());
        } else {
          auto spanned_idents = ParseHierarchicalName();
          auto idents = spanned_idents.data;
          auto span = spanned_idents.span;
          return GetOp(GetHierarchicalName(idents), span);
        }
      }
      case TokenType::kGraph: {
        Consume(TokenType::kGraph);
        return LookupGraphBinding(next);
      }
      case TokenType::kMetaReference: {
        return Downcast(ParseMetaRef());
      }
      case TokenType::kFn: {
        Consume(TokenType::kFn);
        Expr e = ParseFunctionDef();
        ICHECK(e->span.defined()) << "function spans must be defined.\n" << e;
        return e;
      }
      case TokenType::kIf: {
        Expr e = ParseIf();
        return e;
      }
      case TokenType::kRef: {
        // `ref(value)` creates a mutable reference cell.
        Consume(TokenType::kRef);
        Match(TokenType::kOpenParen);
        auto ref_value = ParseExpr();
        Match(TokenType::kCloseParen);
        return static_cast(RefCreate(ref_value));
      }
      case TokenType::kRefRead: {
        return WithSpan([&]() {
          Consume(TokenType::kRefRead);
          Match(TokenType::kOpenParen);
          auto ref = ParseExpr();
          Match(TokenType::kCloseParen);
          return static_cast(RefRead(ref));
        });
      }
      case TokenType::kRefWrite: {
        return WithSpan([&]() {
          Consume(TokenType::kRefWrite);
          Match(TokenType::kOpenParen);
          auto ref = ParseExpr();
          Match(TokenType::kComma);
          auto value = ParseExpr();
          Match(TokenType::kCloseParen);
          return static_cast(RefWrite(ref, value));
        });
      }
      case TokenType::kOpenParen: {
        Span sp = next->span;
        Consume(TokenType::kOpenParen);
        // parse '(' ')'
        if (WhenMatch(TokenType::kCloseParen)) {
          return Expr(Tuple(Array()));
        } else {
          Expr subexpr = ParseExpr();
          // parse '(' expr ')'
          if (WhenMatch(TokenType::kCloseParen)) {
            return subexpr;
            // parse '( expr ',' * ')'
          } else if (WhenMatch(TokenType::kComma)) {
            Array exprs = {subexpr};
            while (true) {
              if (WhenMatch(TokenType::kCloseParen)) {
                break;
              } else {
                auto element = ParseExpr();
                auto comma = Peek();
                // Grow the tuple span to cover each element (and its comma).
                if (WhenMatch(TokenType::kComma)) {
                  sp = sp.Merge(element->span.Merge(comma->span));
                } else {
                  sp = sp.Merge(element->span);
                }
                exprs.push_back(element);
              }
            }
            Expr tuple = Tuple(exprs, sp);
            ICHECK(tuple->span.defined()) << "tuple span should be defined";
            return tuple;
          }
        }
      }
      default: {
        this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
                                 << "expected an expression found " << Pretty(next->token_type));
        return Expr();
      }
    }
  });
  // Optional postfix `.N` parses a tuple projection.
  if (WhenMatch(TokenType::kPeriod)) {
    auto token = Match(TokenType::kInteger);
    auto index = token.ToNumber();
    auto span = token->span.Merge(expr->span);
    VLOG(9) << "Parser::ParseAtomicExpr: tuple get item";
    return relay::TupleGetItem(expr, index, span);
  } else {
    return expr;
  }
}

/*! \brief Parse a hierarchical name.
 *
 * The tokenizer splits a dotted name such as `nn.conv2d` into identifier
 * and period tokens. This utility reassembles the identifier segments
 * (consuming the periods) so the full dotted name can be looked up in the
 * operator registry.
 */
Spanned> ParseHierarchicalName() {
  Array idents;
  Span span;
  while (Peek()->token_type == TokenType::kIdentifier) {
    auto token = Peek();
    // Accumulate a span covering every segment.
    if (span.defined()) {
      span = span.Merge(token->span);
    } else {
      span = token->span;
    }
    auto name = token.ToString();
    idents.push_back(name);
    Consume(TokenType::kIdentifier);
    // Keep parsing while we see a trailing period.
    if (Peek()->token_type == TokenType::kPeriod) {
      Consume(TokenType::kPeriod);
      continue;
    } else {
      // No more periods means we are done!
      break;
    }
  }
  return Spanned>(idents, span);
}

/*! \brief Join name segments with '.' — e.g. ["nn", "conv2d"] -> "nn.conv2d". */
std::string GetHierarchicalName(Array idents) {
  ICHECK_NE(idents.size(), 0);
  std::stringstream hierarchical_name;
  int i = 0;
  int periods = idents.size() - 1;
  for (auto ident : idents) {
    hierarchical_name << ident;
    // Separator after every segment except the last.
    if (i < periods) {
      hierarchical_name << ".";
      i++;
    }
  }
  return hierarchical_name.str();
}

/*! \brief Parse a shape.
*/
Array ParseShape() {
  // Each dimension is a meta reference, `?` (tvm::tir::Any, i.e. an
  // unknown dimension), or an integer literal.
  auto dims = ParseSequence(
      TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
        tvm::PrimExpr dim;
        if (Peek()->token_type == TokenType::kMetaReference) {
          dim = Downcast(ParseMetaRef());
        } else if (WhenMatch(TokenType::kQuestion)) {
          dim = tvm::tir::Any();
        } else {
          dim = Downcast(Match(TokenType::kInteger)->data);
        }
        return dim;
      });
  return dims;
}

/*! \brief Parse a function type: `fn (T1, ..., Tn) -> Ret`.
 * (The leading `fn` has already been consumed by the caller.) */
Type ParseFunctionType() {
  auto ty_params = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
                                 [&]() { return ParseType(); });
  // `->` is tokenized as kMinus followed by kRAngle.
  Match(TokenType::kMinus);
  Match(TokenType::kRAngle);
  auto ret_type = ParseType();
  return relay::FuncType(ty_params, ret_type, {}, {});
}

// Parses a user defined ADT or type variable.
Type ParseNonPrimitiveType(const Token& tok) {
  return WithSpan([&]() {
    auto name = tok.ToString();
    // Prefer an in-scope type variable; otherwise resolve (or create) a
    // global ADT handle by name.
    Type head_type = LookupTypeVar(tok);
    if (!head_type.defined()) {
      // head_type = type_names.Get(name);
      head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle);
    }
    if (!head_type.defined()) {
      diag_ctx.EmitFatal(Diagnostic::Error(tok->span)
                         << "the type constructor `" << name << "` is undefined");
    }
    // Optional type-argument list `[T1, ..., Tn]`.
    Array arg_types;
    if (Peek()->token_type == TokenType::kLSquare) {
      arg_types = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
                                [&]() { return ParseType(); });
    }
    if (arg_types.size()) {
      return static_cast(TypeCall(head_type, arg_types));
    } else {
      // A bare global type var still becomes a zero-argument TypeCall;
      // other heads (e.g. in-scope type variables) are returned as-is.
      if (head_type.as()) {
        return static_cast(TypeCall(head_type, {}));
      } else {
        return static_cast(head_type);
      }
    }
  });
}

/*! \brief Parses a TVM type.
 *
 * This matches either a `Tensor[shape, dtype]`, a user defined ADT, a tuple type,
 * a scalar type or an incomplete type `_`.
*/
Type ParseType() {
  return WithSpan([&]() -> Type {
    auto tok = Peek();
    if (tok->token_type == TokenType::kOpenParen) {
      // `(T1, ..., Tn)` is a tuple type.
      auto tys = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
                               [&]() { return ParseType(); });
      return relay::TupleType(tys);
    } else if (WhenMatch(TokenType::kFn)) {
      return ParseFunctionType();
    } else if (WhenMatch(TokenType::kIdentifier)) {
      auto id = tok.ToString();
      if (id == "Tensor") {
        // `Tensor[(d1, ..., dn), dtype]`.
        Match(TokenType::kLSquare);
        auto shape = ParseShape();
        Match(TokenType::kComma);
        auto dtype_tok = Match(TokenType::kIdentifier);
        auto dtype = DataType(String2DLDataType(dtype_tok.ToString()));
        Match(TokenType::kRSquare);
        return TensorType(shape, dtype);
      } else {
        auto ty = tok.ToString();
        // A scalar dtype name (int*/float*/uint*/bool*) denotes a rank-0
        // tensor type; anything else is an ADT or type variable.
        if (ty.rfind("int", 0) == 0 || ty.find("float", 0) == 0 || ty.find("uint", 0) == 0 ||
            ty.find("bool", 0) == 0) {
          // Need to do better error handling here.
          auto dtype = DataType(String2DLDataType(tok.ToString()));
          return TensorType({}, dtype);
        } else {
          return ParseNonPrimitiveType(tok);
        }
      }
    } else if (WhenMatch(TokenType::kUnderscore)) {
      // `_` is a hole to be filled in by type inference.
      return IncompleteType();
    } else {
      this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span)
                               << "failed to parse type found " << tok);
      return Type();
    }
  });
}

/*! \brief Run `func` with whitespace tokens skipped and whitespace
 * sensitivity temporarily enabled; restores the previous setting after. */
template R ConsumeWhitespace(std::function func) {
  auto old = this->ignore_whitespace;
  this->ignore_whitespace = true;
  while (tokens[pos]->token_type == TokenType::kWhitespace) {
    pos++;
  }
  auto res = func();
  this->ignore_whitespace = old;
  return res;
}

/*! \brief Parse the optional trailing `#[metadata]` section into a meta
 * table; returns an empty table when the section is absent. */
Map> ParseMetadata() {
  if (Peek()->token_type == TokenType::kMetadata) {
    return Match(TokenType::kMetadata).ToMetadata();
  } else {
    return Map>();
  }
}

/*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */
void DisplayNextN(int n) {
  std::cout << "remaining tokens: " << std::endl;
  auto bound = std::min(pos + n, static_cast(tokens.size()));
  for (int i = 0; i < bound - pos; i++) {
    std::cout << tokens[pos + i] << std::endl;
  }
}

// A function for debugging the operator parser.
// Dumps both precedence-parser stacks to stdout (debugging aid only).
void DebugStack(const std::vector& exprs, const std::vector& rules) {
  std::cout << "Expr Stack: ";
  for (auto expr : exprs) {
    std::cout << expr << ", ";
  }
  std::cout << std::endl;
  std::cout << "Op Stack: ";
  for (auto rule : rules) {
    std::cout << rule.op << ", ";
  }
  std::cout << std::endl;
}
};

/*! \brief Construct a ready-to-run Parser over `file_content`.
 *
 * Registers the source with (a possibly pre-existing) IRModule's source
 * map, tokenizes it, and merges `init_meta_table` into any `#[metadata]`
 * section found in the text.
 */
Parser InitParser(const std::string& file_name, const std::string& file_content,
                  const Optional& init_module, const MetaTable& init_meta_table) {
  // NOTE(review): this log line concatenates the file name directly with
  // "file_content_size:" with no separator — likely a missing delimiter;
  // left as-is since it is runtime output.
  VLOG(9) << "InitParser: file_name: " << file_name
          << "file_content_size: " << file_content.size();
  SourceName src_name = SourceName::Get(file_name);
  Source source(src_name, file_content);
  IRModule module;
  if (!init_module) {
    SourceMap source_map;
    module = IRModule({}, {}, {}, source_map);
  } else {
    module = init_module.value();
  }
  module->source_map.Add(source);
  auto diag_ctx = DiagnosticContext::Default(module);
  auto tokens_and_table = Tokenize(diag_ctx, source);
  auto tokens = tokens_and_table.first;
  MetaTable meta_data_table = tokens_and_table.second.ToMetadata();
  // Merge any entries in init_meta_table into anything captured in the #[metadata] section
  // of the file_content. Metadata references within file_content must use indexes which account
  // for this ordering.
  for (const auto& pair : init_meta_table) {
    Array items;
    if (meta_data_table.count(pair.first)) {
      items = meta_data_table[pair.first];
    }
    for (const auto& obj : pair.second) {
      items.push_back(obj);
    }
    meta_data_table.Set(pair.first, items);
  }
  return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table));
}

/*! \brief Parse `file_content` into a type-checked IRModule.
 * Fatal diagnostics are rendered before type inference runs. */
IRModule ParseModule(const std::string& file_name, const std::string& file_content,
                     const Optional& init_module, const MetaTable& init_meta_table) {
  VLOG_CONTEXT << "ParseModule";
  VLOG(9) << "parsing and type-checking " << file_name;
  auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
  auto mod = parser.ParseModule();
  ICHECK(mod.defined()) << "The parser must return a non-null module.";
  // NB(@jroesch): it is very important that we render any errors before we proceed;
  // if there were any errors which allow the parser to proceed we must render them
  // here.
  parser.diag_ctx.Render();
  auto infer_type = tvm::relay::transform::InferType();
  ICHECK(infer_type.defined()) << "The type inferencer must be non-null.";
  return infer_type(mod);
}

/*! \brief Parse a single standalone Relay expression from `file_content`. */
Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
  VLOG(9) << "ParseExpr";
  auto parser = InitParser(file_name, file_content, Optional(), MetaTable());
  parser.ParseSemVer(false);
  parser.PushScope();
  auto expr = parser.ParseExpr();
  parser.Match(TokenType::kEndOfFile);
  // NB(@jroesch): it is very important that we render any errors before we proceed;
  // if there were any errors which allow the parser to proceed we must render them
  // here.
  parser.diag_ctx.Render();
  return expr;
}

TVM_REGISTER_GLOBAL("parser.ParseModuleInContext")
    .set_body_typed([](const std::string& file_name, const std::string& file_content,
                       const Optional& init_module, const MetaTable& init_meta_table) {
      return ParseModule(file_name, file_content, init_module, init_meta_table);
    });

TVM_REGISTER_GLOBAL("parser.ParseModule")
    .set_body_typed([](const std::string& file_name, const std::string& file_content) {
      return ParseModule(file_name, file_content);
    });

TVM_REGISTER_GLOBAL("parser.ParseExpr")
    .set_body_typed([](tvm::String file_name, tvm::String file_content) {
      return ParseExpr(file_name, file_content);
    });

/*!
 * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
 * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
 * modules constructed programmatically rather than textually.
 */
Pass AnnotateSpans() {
  auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
    String text = AsText(mod, /*show_meta_data=*/true);
    VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
    return ParseModule("GeneratedSource", text);
  };
  return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);

}  // namespace parser
}  // namespace tvm