diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 84681ab8c2253..99599d4678c35 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2941,6 +2941,24 @@
     ],
     "sqlState" : "22029"
   },
+  "INVALID_VARIABLE_DECLARATION" : {
+    "message" : [
+      "Invalid variable declaration."
+    ],
+    "subClass" : {
+      "NOT_ALLOWED_IN_SCOPE" : {
+        "message" : [
+          "Variable <varName> was declared on line <lineNumber>, which is not allowed in this scope."
+        ]
+      },
+      "ONLY_AT_BEGINNING" : {
+        "message" : [
+          "Variable <varName> can only be declared at the beginning of the compound, but it was declared on line <lineNumber>."
+        ]
+      }
+    },
+    "sqlState" : "42K0M"
+  },
   "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : {
     "message" : [
       "Variable type must be string type but got <varType>."
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json
index 0cd55bda7ba35..c5c55f11a6aa8 100644
--- a/common/utils/src/main/resources/error/error-states.json
+++ b/common/utils/src/main/resources/error/error-states.json
@@ -4619,6 +4619,12 @@
         "standard": "N",
         "usedBy": ["Spark"]
     },
+    "42K0M": {
+        "description": "Invalid variable declaration.",
+        "origin": "Spark",
+        "standard": "N",
+        "usedBy": ["Spark"]
+    },
     "42KD0": {
         "description": "Ambiguous name reference.",
        "origin": "Databricks",
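Reviewer note: the `<varName>` and `<lineNumber>` placeholders in the templates above are filled from the `messageParameters` map that `SqlScriptingErrors` passes to `SparkException` (see below). Here is a minimal, self-contained sketch of that substitution; `MessageTemplateSketch` and its `fill` helper are illustrative names, not Spark's actual `ErrorClassesJsonReader`:

```scala
// Illustrative sketch only: shows how <varName>/<lineNumber> in the JSON
// templates above get filled from messageParameters. Spark's real
// ErrorClassesJsonReader also resolves the subclass and sqlState.
object MessageTemplateSketch {
  def fill(template: String, params: Map[String, String]): String =
    params.foldLeft(template) { case (msg, (key, value)) =>
      msg.replace(s"<$key>", value)
    }

  def main(args: Array[String]): Unit = {
    val template = "Variable <varName> can only be declared at the beginning " +
      "of the compound, but it was declared on line <lineNumber>."
    // With the parameters asserted in the parser test below, this prints:
    //   Variable `testVariable` can only be declared at the beginning of the
    //   compound, but it was declared on line 4.
    println(fill(template, Map("varName" -> "`testVariable`", "lineNumber" -> "4")))
  }
}
```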
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index f5cf3e717a3ce..a046ededf964c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -48,8 +48,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, con
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
 import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors}
-import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt
+import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.LEGACY_BANG_EQUALS_NOT
 import org.apache.spark.sql.types._
@@ -62,7 +61,8 @@ import org.apache.spark.util.random.RandomSampler
  * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
  * TableIdentifier.
  */
-class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
+class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
+  with Logging with DataTypeErrorsBase {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
   import ParserUtils._
 
@@ -133,12 +133,42 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
 
   private def visitCompoundBodyImpl(
       ctx: CompoundBodyContext,
-      label: Option[String]): CompoundBody = {
+      label: Option[String],
+      allowVarDeclare: Boolean): CompoundBody = {
     val buff = ListBuffer[CompoundPlanStatement]()
     ctx.compoundStatements.forEach(compoundStatement => {
       buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement]
     })
+
+    val compoundStatements = buff.toList
+
+    val candidates = if (allowVarDeclare) {
+      compoundStatements.dropWhile {
+        case SingleStatement(_: CreateVariable) => true
+        case _ => false
+      }
+    } else {
+      compoundStatements
+    }
+
+    val declareVarStatement = candidates.collectFirst {
+      case SingleStatement(c: CreateVariable) => c
+    }
+
+    declareVarStatement match {
+      case Some(c: CreateVariable) =>
+        if (allowVarDeclare) {
+          throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning(
+            toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts),
+            c.origin.line.get.toString)
+        } else {
+          throw SqlScriptingErrors.variableDeclarationNotAllowedInScope(
+            toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts),
+            c.origin.line.get.toString)
+        }
+      case _ =>
+    }
+
     CompoundBody(buff.toSeq, label)
   }
 
@@ -161,11 +191,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
     val labelText = beginLabelCtx.
       map(_.multipartIdentifier().getText).getOrElse(java.util.UUID.randomUUID.toString).
       toLowerCase(Locale.ROOT)
-    visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText))
+    visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
   }
 
   override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
-    visitCompoundBodyImpl(ctx, None)
+    visitCompoundBodyImpl(ctx, None, allowVarDeclare = false)
   }
 
   override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index c1ce93e10553b..8959911dbd8f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -39,4 +39,18 @@ private[sql] object SqlScriptingErrors extends QueryErrorsBase {
       messageParameters = Map("endLabel" -> endLabel))
   }
 
+  def variableDeclarationNotAllowedInScope(varName: String, lineNumber: String): Throwable = {
+    new SparkException(
+      errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE",
+      cause = null,
+      messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber))
+  }
+
+  def variableDeclarationOnlyAtBeginning(varName: String, lineNumber: String): Throwable = {
+    new SparkException(
+      errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING",
+      cause = null,
+      messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber))
+  }
+
 }
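Reviewer note: the check added to `visitCompoundBodyImpl` boils down to a `dropWhile`/`collectFirst` pair. When declarations are allowed, the legal leading run of `DECLARE`s is skipped and any `CreateVariable` that survives is misplaced; when they are not allowed, every `CreateVariable` is a violation. A self-contained sketch of just that logic, using stand-in case classes (`Declare`, `Other`) instead of the real Catalyst nodes:

```scala
// Stand-ins for SingleStatement(CreateVariable) and other compound statements.
sealed trait Stmt
final case class Declare(name: String, line: Int) extends Stmt
final case class Other(sql: String) extends Stmt

object DeclarePlacementSketch {
  // Returns the first misplaced declaration, if any; mirrors the
  // dropWhile/collectFirst structure in visitCompoundBodyImpl.
  def firstBadDeclare(stmts: List[Stmt], allowVarDeclare: Boolean): Option[Declare] = {
    val candidates =
      if (allowVarDeclare) stmts.dropWhile(_.isInstanceOf[Declare]) // skip legal prefix
      else stmts // no declaration is legal in this scope
    candidates.collectFirst { case d: Declare => d }
  }

  def main(args: Array[String]): Unit = {
    // Leading DECLAREs in a BEGIN...END block are fine.
    assert(firstBadDeclare(
      List(Declare("v1", 2), Other("SELECT 1")), allowVarDeclare = true).isEmpty)
    // A DECLARE after the first non-DECLARE statement is flagged.
    assert(firstBadDeclare(
      List(Other("SELECT 1"), Declare("v1", 3)), allowVarDeclare = true)
      .contains(Declare("v1", 3)))
    // In a scope that bans declarations, even a leading one is flagged.
    assert(firstBadDeclare(
      List(Declare("v1", 2)), allowVarDeclare = false).contains(Declare("v1", 2)))
  }
}
```

Note that the real method still builds `CompoundBody(buff.toSeq, label)` from the unfiltered buffer; `candidates` exists only for validation, so the statement list itself is never reordered or trimmed.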
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index d4eb5fd747ac4..47d7f76742639 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.plans.logical.CreateVariable
 
 class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
   import CatalystSqlParser._
@@ -263,6 +264,37 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     assert(tree.label.nonEmpty)
   }
 
+  test("declare at the beginning") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE testVariable1 VARCHAR(50);
+        |  DECLARE testVariable2 INTEGER;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 2)
+    assert(tree.collection.forall(_.isInstanceOf[SingleStatement]))
+    assert(tree.collection.forall(
+      _.asInstanceOf[SingleStatement].parsedPlan.isInstanceOf[CreateVariable]))
+  }
+
+  test("declare after beginning") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  SELECT 1;
+        |  DECLARE testVariable INTEGER;
+        |END""".stripMargin
+    checkError(
+      exception = intercept[SparkException] {
+        parseScript(sqlScriptText)
+      },
+      errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING",
+      parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4"))
+  }
+
+  // TODO Add test for INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE exception
+
   test("SET VAR statement test") {
     val sqlScriptText =
       """