From 2ee8e60b0fb7a289bdbc2ee54f517642b608780d Mon Sep 17 00:00:00 2001 From: Momcilo Mrkaic Date: Thu, 18 Jul 2024 15:30:39 +0200 Subject: [PATCH 1/6] Added variable declaration checks --- .../sql/catalyst/parser/AstBuilder.scala | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f5cf3e717a3ce..2606a7db014d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -133,12 +133,36 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { private def visitCompoundBodyImpl( ctx: CompoundBodyContext, - label: Option[String]): CompoundBody = { + label: Option[String], + allowPrefixDeclare: Boolean): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() ctx.compoundStatements.forEach(compoundStatement => { buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement] }) + val compoundStatements = buff.toList + + if (allowPrefixDeclare) { + val declareAfterPrefix = compoundStatements.dropWhile( + statement => statement.isInstanceOf[SingleStatement] && + statement.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) + .filter(_.isInstanceOf[SingleStatement]) + .exists(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) + + if(declareAfterPrefix) { + throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning() + } + + } else { + val declareExists = compoundStatements + .filter(_.isInstanceOf[SingleStatement]) + .exists(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) + + if (declareExists) { + throw SqlScriptingErrors.variableDeclarationNotAllowedInScope() + } + } + CompoundBody(buff.toSeq, label) } @@ -161,11 +185,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val labelText = beginLabelCtx. map(_.multipartIdentifier().getText).getOrElse(java.util.UUID.randomUUID.toString). toLowerCase(Locale.ROOT) - visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText)) + visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), true) } override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = { - visitCompoundBodyImpl(ctx, None) + visitCompoundBodyImpl(ctx, None, false) } override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement = From dc136babeea16095e6b67c66ae6cc592f3c3c69b Mon Sep 17 00:00:00 2001 From: Momcilo Mrkaic Date: Thu, 18 Jul 2024 15:31:15 +0200 Subject: [PATCH 2/6] Added tests for variable declaration errors --- .../resources/error/error-conditions.json | 17 ++++++++++ .../main/resources/error/error-states.json | 6 ++++ .../spark/sql/errors/SqlScriptingErrors.scala | 16 ++++++++++ .../parser/SqlScriptingParserSuite.scala | 32 +++++++++++++++++++ 4 files changed, 71 insertions(+) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 84681ab8c2253..e16661f032b95 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2941,6 +2941,23 @@ ], "sqlState" : "22029" }, + "INVALID_VARIABLE_DECLARATION": { + "message" : [ + "Invalid variable declaration." 
+    ],
+    "subClass" : {
+      "NOT_ALLOWED_IN_SCOPE" : {
+        "message": [
+          "Variable declaration is not allowed in this scope."
+        ]
+      },
+      "ONLY_AT_BEGINNING" : {
+        "message": [
+          "Variable declaration is only possible at the beginning of the BEGIN END compound."
+        ]
+      }
+    }
+  },
   "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : {
     "message" : [
       "Variable type must be string type but got <varType>."
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json
index 0cd55bda7ba35..c5c55f11a6aa8 100644
--- a/common/utils/src/main/resources/error/error-states.json
+++ b/common/utils/src/main/resources/error/error-states.json
@@ -4619,6 +4619,12 @@
         "standard": "N",
         "usedBy": ["Spark"]
     },
+    "42K0M": {
+        "description": "Invalid variable declaration.",
+        "origin": "Spark",
+        "standard": "N",
+        "usedBy": ["Spark"]
+    },
     "42KD0": {
         "description": "Ambiguous name reference.",
         "origin": "Databricks",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index c1ce93e10553b..ffe0218e10970 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -39,4 +39,20 @@ private[sql] object SqlScriptingErrors extends QueryErrorsBase {
       messageParameters = Map("endLabel" -> endLabel))
   }
 
+  def variableDeclarationNotAllowedInScope(): Throwable = {
+    new SparkException(
+      errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE",
+      cause = null,
+      messageParameters = Map()
+    )
+  }
+
+  def variableDeclarationOnlyAtBeginning(): Throwable = {
+    new SparkException(
+      errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING",
+      cause = null,
+      messageParameters = Map()
+    )
+  }
+
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index d4eb5fd747ac4..3d2d0eb79f6e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.plans.logical.CreateVariable
 
 class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
   import CatalystSqlParser._
@@ -263,6 +264,37 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     assert(tree.label.nonEmpty)
   }
 
+  test("declare at the beginning") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE testVariable1 VARCHAR(50);
+        |  DECLARE testVariable2 INTEGER;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 2)
+    assert(tree.collection.forall(_.isInstanceOf[SingleStatement]))
+    assert(tree.collection.forall(
+      _.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]))
+  }
+
+  test("declare after beginning") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  SELECT 1;
+        |  DECLARE testVariable1 INTEGER;
+        |END""".stripMargin
+    checkError(
+      exception = intercept[SparkException] {
+        parseScript(sqlScriptText)
+      },
+      errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING",
+      parameters = Map())
+  }
+
+  // TODO Add test for INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE exception
+
   test("SET VAR statement test") {
     val sqlScriptText =
       """

From 14f175d7198cd92c57ae75d4d90b42323f61dbe2 Mon Sep 17 00:00:00 2001
From: Momcilo Mrkaic
Date: Fri, 19 Jul 2024 10:52:21 +0200
Subject: [PATCH 3/6] Added variable name and line number to error message

---
 .../resources/error/error-conditions.json     |  7 +--
 .../sql/catalyst/parser/AstBuilder.scala      | 45 +++++++++++--------
 .../spark/sql/errors/SqlScriptingErrors.scala | 10 ++---
 .../parser/SqlScriptingParserSuite.scala      |  4 +-
 4 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index e16661f032b95..c2effcebbbf60 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2948,15 +2948,16 @@
     "subClass" : {
       "NOT_ALLOWED_IN_SCOPE" : {
         "message": [
-          "Variable declaration is not allowed in this scope."
+          "Variable <varName> was declared on line <lineNumber>, which is not allowed in this scope."
         ]
       },
       "ONLY_AT_BEGINNING" : {
         "message": [
-          "Variable declaration is only possible at the beginning of the BEGIN END compound."
+          "Variable <varName> can only be declared at the beginning of the compound, but it was declared on line <lineNumber>."
         ]
       }
-    }
+    },
+    "sqlState": "42K0M"
   },
   "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : {
     "message" : [
       "Variable type must be string type but got <varType>."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 2606a7db014d0..9d4577784d47a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -19,15 +19,12 @@ package org.apache.spark.sql.catalyst.parser
 
 import java.util.Locale
 import java.util.concurrent.TimeUnit
-
 import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set}
 import scala.jdk.CollectionConverters._
 import scala.util.{Left, Right}
-
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.misc.Interval
 import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
-
 import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION
@@ -47,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, Inte
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
-import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
+import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors,
SqlScriptingErrors} import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt import org.apache.spark.sql.internal.SQLConf @@ -134,7 +131,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { private def visitCompoundBodyImpl( ctx: CompoundBodyContext, label: Option[String], - allowPrefixDeclare: Boolean): CompoundBody = { + allowVarDeclare: Boolean): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() ctx.compoundStatements.forEach(compoundStatement => { buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement] @@ -142,24 +139,34 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val compoundStatements = buff.toList - if (allowPrefixDeclare) { - val declareAfterPrefix = compoundStatements.dropWhile( - statement => statement.isInstanceOf[SingleStatement] && + if (allowVarDeclare) { + val declareAfterPrefix = compoundStatements + .dropWhile(statement => statement.isInstanceOf[SingleStatement] && statement.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) .filter(_.isInstanceOf[SingleStatement]) - .exists(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) - - if(declareAfterPrefix) { - throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning() + .find(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) + + declareAfterPrefix match { + case Some(SingleStatement(parsedPlan)) => + throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( + parsedPlan.asInstanceOf[CreateVariable].name. + asInstanceOf[UnresolvedIdentifier].nameParts.last, + parsedPlan.origin.line.get.toString) + case _ => } } else { - val declareExists = compoundStatements + val declare = compoundStatements .filter(_.isInstanceOf[SingleStatement]) - .exists(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) - - if (declareExists) { - throw SqlScriptingErrors.variableDeclarationNotAllowedInScope() + .find(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) + + declare match { + case Some(SingleStatement(parsedPlan)) => + throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( + parsedPlan.asInstanceOf[CreateVariable].name. + asInstanceOf[UnresolvedIdentifier].nameParts.last, + parsedPlan.origin.line.get.toString) + case _ => } } @@ -185,11 +192,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val labelText = beginLabelCtx. map(_.multipartIdentifier().getText).getOrElse(java.util.UUID.randomUUID.toString). 
toLowerCase(Locale.ROOT) - visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), true) + visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true) } override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = { - visitCompoundBodyImpl(ctx, None, false) + visitCompoundBodyImpl(ctx, None, allowVarDeclare = false) } override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index ffe0218e10970..8959911dbd8f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -39,20 +39,18 @@ private[sql] object SqlScriptingErrors extends QueryErrorsBase { messageParameters = Map("endLabel" -> endLabel)) } - def variableDeclarationNotAllowedInScope(): Throwable = { + def variableDeclarationNotAllowedInScope(varName: String, lineNumber: String): Throwable = { new SparkException( errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE", cause = null, - messageParameters = Map() - ) + messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber)) } - def variableDeclarationOnlyAtBeginning(): Throwable = { + def variableDeclarationOnlyAtBeginning(varName: String, lineNumber: String): Throwable = { new SparkException( errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", cause = null, - messageParameters = Map() - ) + messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 3d2d0eb79f6e3..afb4c1425355c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -283,14 +283,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { """ |BEGIN | SELECT 1; - | DECLARE testVariable1 INTEGER; + | DECLARE testVariable INTEGER; |END""".stripMargin checkError( exception = intercept[SparkException] { parseScript(sqlScriptText) }, errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", - parameters = Map()) + parameters = Map("varName" -> "testVariable", "lineNumber" -> "4")) } // TODO Add test for INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE exception From 881bacbd6c3953cb017c1a492b22b542cea89a16 Mon Sep 17 00:00:00 2001 From: Momcilo Mrkaic Date: Fri, 19 Jul 2024 13:22:09 +0200 Subject: [PATCH 4/6] Fixed scalastyle --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9d4577784d47a..801a683648e32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit + import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} 
 import scala.jdk.CollectionConverters._
 import scala.util.{Left, Right}
+
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.misc.Interval
 import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+
 import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION
@@ -44,7 +47,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, Inte
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
-import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression}
+import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors,
   SqlScriptingErrors}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt
 import org.apache.spark.sql.internal.SQLConf

From a58dbc0ad6800a0bf4f81ec2f18665c54e3230ef Mon Sep 17 00:00:00 2001
From: Momcilo Mrkaic
Date: Mon, 22 Jul 2024 09:47:54 +0200
Subject: [PATCH 5/6] Fix formatting

---
 .../main/resources/error/error-conditions.json |  8 ++++----
 .../spark/sql/catalyst/parser/AstBuilder.scala | 18 ++++++++++--------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index c2effcebbbf60..99599d4678c35 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2941,23 +2941,23 @@
     ],
     "sqlState" : "22029"
   },
-  "INVALID_VARIABLE_DECLARATION": {
+  "INVALID_VARIABLE_DECLARATION" : {
     "message" : [
       "Invalid variable declaration."
     ],
     "subClass" : {
       "NOT_ALLOWED_IN_SCOPE" : {
-        "message": [
+        "message" : [
           "Variable <varName> was declared on line <lineNumber>, which is not allowed in this scope."
         ]
       },
       "ONLY_AT_BEGINNING" : {
-        "message": [
+        "message" : [
           "Variable <varName> can only be declared at the beginning of the compound, but it was declared on line <lineNumber>."
] } }, - "sqlState": "42K0M" + "sqlState" : "42K0M" }, "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : { "message" : [ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 801a683648e32..f0eb5403755ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -143,31 +143,33 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val compoundStatements = buff.toList if (allowVarDeclare) { - val declareAfterPrefix = compoundStatements + val declareVarStatement = compoundStatements .dropWhile(statement => statement.isInstanceOf[SingleStatement] && statement.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) .filter(_.isInstanceOf[SingleStatement]) .find(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) - declareAfterPrefix match { + declareVarStatement match { case Some(SingleStatement(parsedPlan)) => throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( - parsedPlan.asInstanceOf[CreateVariable].name. - asInstanceOf[UnresolvedIdentifier].nameParts.last, + parsedPlan.asInstanceOf[CreateVariable] + .name.asInstanceOf[UnresolvedIdentifier] + .nameParts.last, parsedPlan.origin.line.get.toString) case _ => } } else { - val declare = compoundStatements + val declareVarStatement = compoundStatements .filter(_.isInstanceOf[SingleStatement]) .find(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]) - declare match { + declareVarStatement match { case Some(SingleStatement(parsedPlan)) => throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( - parsedPlan.asInstanceOf[CreateVariable].name. 
- asInstanceOf[UnresolvedIdentifier].nameParts.last, + parsedPlan.asInstanceOf[CreateVariable] + .name.asInstanceOf[UnresolvedIdentifier] + .nameParts.last, parsedPlan.origin.line.get.toString) case _ => } From 07df52cf25c1915a4e8ed18e2853cf865b0d2a4c Mon Sep 17 00:00:00 2001 From: Momcilo Mrkaic Date: Tue, 23 Jul 2024 17:45:18 +0200 Subject: [PATCH 6/6] Resolved comments --- .../sql/catalyst/parser/AstBuilder.scala | 54 +++++++++---------- .../parser/SqlScriptingParserSuite.scala | 2 +- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f0eb5403755ce..a046ededf964c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -48,8 +48,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, con import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors} -import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt +import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LEGACY_BANG_EQUALS_NOT import org.apache.spark.sql.types._ @@ -62,7 +61,8 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. 
 */
-class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
+class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
+  with Logging with DataTypeErrorsBase {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
   import ParserUtils._
 
@@ -142,37 +142,31 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
 
     val compoundStatements = buff.toList
 
-    if (allowVarDeclare) {
-      val declareVarStatement = compoundStatements
-        .dropWhile(statement => statement.isInstanceOf[SingleStatement] &&
-          statement.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable])
-        .filter(_.isInstanceOf[SingleStatement])
-        .find(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable])
-
-      declareVarStatement match {
-        case Some(SingleStatement(parsedPlan)) =>
-          throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning(
-            parsedPlan.asInstanceOf[CreateVariable]
-              .name.asInstanceOf[UnresolvedIdentifier]
-              .nameParts.last,
-            parsedPlan.origin.line.get.toString)
-        case _ =>
+    val candidates = if (allowVarDeclare) {
+      compoundStatements.dropWhile {
+        case SingleStatement(_: CreateVariable) => true
+        case _ => false
       }
     } else {
-      val declareVarStatement = compoundStatements
-        .filter(_.isInstanceOf[SingleStatement])
-        .find(_.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable])
+      compoundStatements
+    }
 
-      declareVarStatement match {
-        case Some(SingleStatement(parsedPlan)) =>
-          throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning(
-            parsedPlan.asInstanceOf[CreateVariable]
-              .name.asInstanceOf[UnresolvedIdentifier]
-              .nameParts.last,
-            parsedPlan.origin.line.get.toString)
-        case _ =>
-      }
+    val declareVarStatement = candidates.collectFirst {
+      case SingleStatement(c: CreateVariable) => c
+    }
+
+    declareVarStatement match {
+      case Some(c: CreateVariable) =>
+        if (allowVarDeclare) {
+          throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning(
+            toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts),
+            c.origin.line.get.toString)
+        } else {
+          throw SqlScriptingErrors.variableDeclarationNotAllowedInScope(
+            toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts),
+            c.origin.line.get.toString)
+        }
+      case _ =>
     }
 
     CompoundBody(buff.toSeq, label)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index afb4c1425355c..47d7f76742639 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -290,7 +290,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
         parseScript(sqlScriptText)
       },
       errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING",
-      parameters = Map("varName" -> "testVariable", "lineNumber" -> "4"))
+      parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4"))
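
---

Note on the TODO that remains in SqlScriptingParserSuite: the series still has no
test for INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE, since at this point
only BEGIN ... END compounds are parseable and those allow a declare prefix. A
minimal sketch of such a test is shown below. It assumes some construct whose
body is parsed through visitCompoundBody (allowVarDeclare = false) eventually
exists; the IF ... END IF block used here is hypothetical and only illustrates
the shape. The expected parameters mirror the ONLY_AT_BEGINNING test above.

  test("declare in wrong scope") {
    // Hypothetical script: assumes IF ... END IF is parseable and its body is
    // visited via visitCompoundBody, where variable declarations are rejected.
    val sqlScriptText =
      """
        |BEGIN
        |  IF 1 = 1 THEN
        |    DECLARE testVariable INTEGER;
        |  END IF;
        |END""".stripMargin
    checkError(
      exception = intercept[SparkException] {
        parseScript(sqlScriptText)
      },
      errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE",
      // DECLARE sits on line 4 of the script, matching the test above.
      parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4"))
  }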