/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

/*
 * This file contains code from the Apache Spark project (original license below).
 * It contains modifications, which are licensed as above:
 */

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.opensearch.flint.spark.sql

import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.{DataType, StructType}

/**
 * Flint SQL parser that extends Spark SQL parser with Flint SQL statements.
 *
 * @param sparkParser
 *   Spark SQL parser
 */
class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface {

  /** Flint AST builder. */
  private val flintAstBuilder = new FlintSparkSqlAstBuilder()
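
  // The override below tries the Flint grammar first and falls back to the wrapped Spark
  // parser whenever Flint cannot parse the statement. A hedged sketch of the observable
  // behavior (the statements are illustrative only; the authoritative grammar is the
  // FlintSparkSqlExtensions ANTLR definition):
  //
  //   val parser = new FlintSparkSqlParser(spark.sessionState.sqlParser)
  //   // Handled by flintAstBuilder (assuming a Flint skipping-index statement):
  //   parser.parsePlan("CREATE SKIPPING INDEX ON alb_logs (elb_status_code VALUE_SET)")
  //   // Not Flint SQL, so it falls through to Spark's own parser:
  //   parser.parsePlan("SELECT * FROM alb_logs")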

  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser =>
    try {
      flintAstBuilder.visit(flintParser.singleStatement())
    } catch {
      // Fall back to Spark parse plan logic if Flint cannot parse
      case _: ParseException => sparkParser.parsePlan(sqlText)
    }
  }

  override def parseExpression(sqlText: String): Expression =
    sparkParser.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    sparkParser.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    sparkParser.parseFunctionIdentifier(sqlText)

  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
    sparkParser.parseMultipartIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType =
    sparkParser.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType =
    sparkParser.parseDataType(sqlText)

  override def parseQuery(sqlText: String): LogicalPlan =
    sparkParser.parseQuery(sqlText)

  // Starting from here is copied and modified from Spark 3.3.1

  protected def parse[T](sqlText: String)(toResult: FlintSparkSqlExtensionsParser => T): T = {
    val lexer = new FlintSparkSqlExtensionsLexer(
      new UpperCaseCharStream(CharStreams.fromString(sqlText)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new FlintSparkSqlExtensionsParser(tokenStream)
    parser.addParseListener(FlintPostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)

    try {
      try {
        // First, try parsing with the potentially faster SLL mode
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      } catch {
        case _: ParseCancellationException =>
          // If SLL fails, rewind the input stream, reset the parser and try again
          // with the full (slower but more general) LL prediction mode
          tokenStream.seek(0)
          parser.reset()
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    } catch {
      case e: ParseException if e.command.isDefined =>
        throw e
      case e: ParseException =>
        throw e.withCommand(sqlText)
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(
          Option(sqlText),
          e.message,
          position,
          position,
          e.errorClass,
          e.messageParameters)
    }
  }
}

class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  override def consume(): Unit = wrapped.consume()
  override def getSourceName: String = wrapped.getSourceName
  override def index(): Int = wrapped.index
  override def mark(): Int = wrapped.mark
  override def release(marker: Int): Unit = wrapped.release(marker)
  override def seek(where: Int): Unit = wrapped.seek(where)
  override def size(): Int = wrapped.size

  override def getText(interval: Interval): String = wrapped.getText(interval)

  // Upper-case each code point seen by the lexer so that keyword matching is
  // case-insensitive, while getText above still returns the original input unchanged.
  override def LA(i: Int): Int = {
    val la = wrapped.LA(i)
    if (la == 0 || la == IntStream.EOF) la
    else Character.toUpperCase(la)
  }
}
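
// Illustrative sketch (not part of the original source): UpperCaseCharStream gives the lexer
// an upper-cased view of the input while getText still returns the original characters, which
// is what makes Flint keywords case-insensitive without losing the user's casing:
//
//   val stream = new UpperCaseCharStream(CharStreams.fromString("create skipping index"))
//   stream.LA(1)                        // 'C'.toInt - the lexer sees upper-case code points
//   stream.getText(new Interval(0, 5))  // "create" - original casing is preserved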

case object FlintPostProcessor extends FlintSparkSqlExtensionsBaseListener {

  /** Remove the back ticks from an Identifier. */
  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
    replaceTokenByIdentifier(ctx, 1) { token =>
      // Remove the double back ticks in the string.
      token.setText(token.getText.replace("``", "`"))
      token
    }
  }

  /** Treat non-reserved keywords as Identifiers. */
  override def exitNonReserved(ctx: NonReservedContext): Unit = {
    replaceTokenByIdentifier(ctx, 0)(identity)
  }

  /**
   * Replace the matched token with an IDENTIFIER token, stripping `stripMargins` characters
   * from each end (e.g. the surrounding back ticks of a quoted identifier), then apply `f`
   * to post-process the new token before attaching it back to the parse tree.
   */
  private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)(
      f: CommonToken => CommonToken = identity): Unit = {
    val parent = ctx.getParent
    parent.removeLastChild()
    val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val newToken = new CommonToken(
      new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
      IDENTIFIER,
      token.getChannel,
      token.getStartIndex + stripMargins,
      token.getStopIndex - stripMargins)
    parent.addChild(new TerminalNodeImpl(f(newToken)))
  }
}
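
// A hedged walk-through of the rewrite above (token texts are illustrative only):
//
//   input statement:  ... ON `my``table` ...
//   before:           quoted identifier token with text "`my``table`"
//   stripMargins = 1: start/stop indexes move inward, dropping the outer back ticks
//   f(newToken):      "my``table" becomes "my`table" after the double-tick replacement
//   result:           an IDENTIFIER token the parser treats like any other identifier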