From 0dea6c6a3a0ee857fe9db2df4efcf2d2e2ea5f00 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 18 Jun 2024 16:53:32 +0200 Subject: [PATCH 01/99] Initial changes for SQL scripting interpreter --- .../scripting/SqlScriptingExecutionNode.scala | 94 +++++++++++ .../scripting/SqlScriptingInterpreter.scala | 59 +++++++ .../SqlScriptingIntegrationSuite.scala | 153 ++++++++++++++++++ .../SqlScriptingInterpreterSuite.scala | 81 ++++++++++ 4 files changed, 387 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala new file mode 100644 index 0000000000000..14a56dcaa462d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} + +sealed trait CompoundStatementExec extends Logging { + val isInternal: Boolean = false + def reset(): Unit +} + +trait LeafStatementExec extends CompoundStatementExec + +trait NonLeafStatementExec extends CompoundStatementExec with Iterator[CompoundStatementExec] + +class SingleStatementExec( + var parsedPlan: LogicalPlan, + override val origin: Origin, + override val isInternal : Boolean) + extends LeafStatementExec + with WithOrigin { + + var consumed = false + + override def reset(): Unit = consumed = false + + def getText(sqlScriptText: String): String = { + if (origin.startIndex.isEmpty || origin.stopIndex.isEmpty) { + return null + } + sqlScriptText.substring(origin.startIndex.get, origin.stopIndex.get + 1) + } +} + +abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) + extends NonLeafStatementExec { + + var localIterator = collection.iterator + var curr = if (localIterator.hasNext) Some(localIterator.next()) else None + + override def hasNext: Boolean = { + val childHasNext = curr match { + case Some(body: NonLeafStatementExec) => body.hasNext + case Some(_: LeafStatementExec) => true + case None => false + case _ => throw new IllegalStateException("Unknown statement type") + } + localIterator.hasNext || childHasNext + } + + override def next(): CompoundStatementExec = { + curr match { + case None => throw new IllegalStateException("No more elements") + case Some(statement: LeafStatementExec) => + if (localIterator.hasNext) curr = Some(localIterator.next()) + else curr = None + statement + case Some(body: NonLeafStatementExec) => + if (body.hasNext) { + body.next() + } else { + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + next() + } + case _ => throw new IllegalStateException("Unknown statement type") + } + } + + override def reset(): Unit = { + collection.foreach(_.reset()) + localIterator = collection.iterator + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + } +} + +class CompoundBodyExec(statements: Seq[CompoundStatementExec]) + extends CompoundNestedStatementIteratorExec(statements) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala new file mode 100644 index 0000000000000..021f5146e5d5e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.Origin + +trait ProceduralLanguageInterpreter { + def buildExecutionPlan(compound: CompoundBody) : Iterator[CompoundStatementExec] +} + +case class SqlScriptingInterpreter() extends ProceduralLanguageInterpreter { + override def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { + transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec] + } + + private def getDeclareVarNameFromPlan(plan: LogicalPlan): Option[UnresolvedIdentifier] = + plan match { + case CreateVariable(name: UnresolvedIdentifier, _, _) => Some(name) + case _ => None + } + + private def transformTreeIntoExecutable(node: CompoundPlanStatement): CompoundStatementExec = + node match { + case body: CompoundBody => + val variables = body.collection.flatMap { + case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) + case _ => None + } + val dropVariables = variables + .map(varName => DropVariable(varName, ifExists = true)) + .map(new SingleStatementExec(_, Origin(), isInternal = true)) + .reverse + new CompoundBodyExec( + body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables) + case sparkStatement: SingleStatement => + new SingleStatementExec( + sparkStatement.parsedPlan, + sparkStatement.origin, + isInternal = false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala new file mode 100644 index 0000000000000..4d7af0f271c9d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.sql.{Dataset, QueryTest, Row} +import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.test.SharedSparkSession + +class SqlScriptingIntegrationSuite extends QueryTest with SharedSparkSession { + // Helpers + private def verifySqlScriptResult( + sqlText: String, expected: Seq[Seq[Row]], printResult: Boolean = false): Unit = { + val interpreter = SqlScriptingInterpreter() + val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) + val executionPlan = interpreter.buildExecutionPlan(compoundBody) + val result = executionPlan.flatMap { + case statement: SingleStatementExec => + if (printResult) { + // scalastyle:off println + println("Executing: " + statement.getText(sqlText)) + // scalastyle:on println + } + + if (statement.consumed) { + None + } else { + Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) + } + case _ => None + }.toArray + + assert(result.length == expected.length) + result.zip(expected).foreach{ case (df, expectedAnswer) => checkAnswer(df, expectedAnswer)} + } + + // Tests + test("select 1") { + verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) + } + + test("multi statement - simple") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT a, b FROM t WHERE a = 12; + |SELECT a FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // select with filter + Seq(Row(1)) + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("multi statement - count") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT + | CASE WHEN COUNT(*) > 10 THEN true + | ELSE false + | END AS MoreThanTen + |FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert #1 + Seq.empty[Row], // insert #2 + Seq(Row(false)) + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("session vars - set and read") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(2)), // select + Seq.empty[Row] // drop var + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - set and read scoped") { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE var = 1; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 2; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 3; + | SET VAR var = var + 1; + | SELECT var; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row], // drop var + Seq.empty[Row], // declare var + Seq(Row(2)), // select + Seq.empty[Row], // drop var + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(4)), // select + Seq.empty[Row], // drop var + ) + verifySqlScriptResult(sqlScript, expected) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala new file mode 100644 index 0000000000000..de94519861d7c --- 
/dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.catalyst.trees.Origin + +class SqlScriptingInterpreterSuite extends SparkFunSuite { + // Helpers + case class TestLeafStatement(testVal: String) extends LeafStatementExec { + override def reset(): Unit = () + } + + case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) + extends CompoundNestedStatementIteratorExec(statements) + + case class TestBody(statements: Seq[CompoundStatementExec]) + extends CompoundBodyExec(statements) + + case class TestSparkStatementWithPlan(testVal: String) + extends SingleStatementExec( + parsedPlan = Project(Seq(Alias(Literal(testVal), "condition")()), OneRowRelation()), + Origin(startIndex = Some(0), stopIndex = Some(testVal.length)), + isInternal = false) + + // Tests + test("test body - single statement") { + val iter = TestNestedStatementIterator(Seq(TestLeafStatement("one"))) + val statements = iter.map { + case TestLeafStatement(v) => v + case _ => fail("Unexpected statement type") + }.toList + + assert(statements === List("one")) + } + + test("test body - no nesting") { + val iter = TestNestedStatementIterator( + Seq( + TestLeafStatement("one"), + TestLeafStatement("two"), + TestLeafStatement("three"))) + val statements = iter.map { + case TestLeafStatement(v) => v + case _ => fail("Unexpected statement type") + }.toList + + assert(statements === Seq("one", "two", "three")) + } + + test("test body - nesting") { + val iter = TestNestedStatementIterator( + Seq( + TestNestedStatementIterator(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), + TestLeafStatement("three"), + TestNestedStatementIterator(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) + val statements = iter.map { + case TestLeafStatement(v) => v + case _ => fail("Unexpected statement type") + }.toList + + assert(statements === Seq("one", "two", "three", "four", "five")) + } +} \ No newline at end of file From d7b195047cb3b6bcd6ec529a70aaf14bb0b3a3d1 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 18 Jun 2024 18:06:38 +0200 Subject: [PATCH 02/99] Add comments --- .../scripting/SqlScriptingExecutionNode.scala | 40 +++++++++++++++++++ .../scripting/SqlScriptingInterpreter.scala | 23 +++++++++++ 2 files changed, 63 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 14a56dcaa462d..cb4aaf817b83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -21,15 +21,40 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} +/** + * Trait for all SQL scripting execution nodes used during interpretation phase. + */ sealed trait CompoundStatementExec extends Logging { + /** + * Whether the statement originates from the script or it is created during the interpretation. + * Example: DropVariable statements are automatically created at the end of each compound. + */ val isInternal: Boolean = false + + /** + * Reset execution of the current node. + */ def reset(): Unit } +/** + * Leaf node in the execution tree. + */ trait LeafStatementExec extends CompoundStatementExec +/** + * Non-leaf node in the execution tree. + * It is an iterator over executable child nodes. + */ trait NonLeafStatementExec extends CompoundStatementExec with Iterator[CompoundStatementExec] +/** + * Executable node for SingleStatement. + * @param parsedPlan Logical plan of the parsed statement. + * @param origin Origin descriptor for the statement. + * @param isInternal Whether the statement originates from the script + * or it is created during the interpretation. + */ class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, @@ -37,10 +62,16 @@ class SingleStatementExec( extends LeafStatementExec with WithOrigin { + /** + * Whether this statement had to be executed during the interpretation phase. + * Example: Statements in conditions of If/Else, While, etc. + */ var consumed = false + /** @inheritdoc */ override def reset(): Unit = consumed = false + /** Get the SQL query text corresponding to this statement. */ def getText(sqlScriptText: String): String = { if (origin.startIndex.isEmpty || origin.stopIndex.isEmpty) { return null @@ -49,6 +80,11 @@ class SingleStatementExec( } } +/** + * Abstract class for all statements that contain nested statements. + * Implements recursive iterator logic over all child execution nodes. + * @param collection Collection of child execution nodes. + */ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) extends NonLeafStatementExec { @@ -90,5 +126,9 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState } } +/** + * Executable node for CompoundBody. + * @param statements Executable nodes for nested statements within the CompoundBody. 
+ */ class CompoundBodyExec(statements: Seq[CompoundStatementExec]) extends CompoundNestedStatementIteratorExec(statements) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 021f5146e5d5e..af9d9f0444cd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -22,21 +22,44 @@ import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin +/** + * Trait for SQL Scripting interpreters. + */ trait ProceduralLanguageInterpreter { + /** + * Build execution plan and return statements that need to be executed, + * wrapped in the execution node. + * @param compound CompoundBody for which to build the plan. + * @return Iterator through collection of statements to be executed. + */ def buildExecutionPlan(compound: CompoundBody) : Iterator[CompoundStatementExec] } +/** + * Concrete implementation of the interpreter for SQL scripting. + */ case class SqlScriptingInterpreter() extends ProceduralLanguageInterpreter { + /** @inheritdoc */ override def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec] } + /** + * Fetch the name of the Create Variable plan. + * @param plan Plan to fetch the name from. + * @return Name of the variable. + */ private def getDeclareVarNameFromPlan(plan: LogicalPlan): Option[UnresolvedIdentifier] = plan match { case CreateVariable(name: UnresolvedIdentifier, _, _) => Some(name) case _ => None } + /** + * Transform the parsed tree to the executable node. + * @param node Root node of the parsed tree. + * @return Executable statement. + */ private def transformTreeIntoExecutable(node: CompoundPlanStatement): CompoundStatementExec = node match { case body: CompoundBody => From 08727696e37ff34e6e5438aa1388f70ecfe06146 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 19 Jun 2024 10:31:07 +0200 Subject: [PATCH 03/99] Whitespace changes --- .../apache/spark/sql/scripting/SqlScriptingExecutionNode.scala | 2 +- .../spark/sql/scripting/SqlScriptingIntegrationSuite.scala | 2 +- .../spark/sql/scripting/SqlScriptingInterpreterSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index cb4aaf817b83e..2420a60983927 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} */ sealed trait CompoundStatementExec extends Logging { /** - * Whether the statement originates from the script or it is created during the interpretation. + * Whether the statement originates from the SQL script or it is created during the interpretation. * Example: DropVariable statements are automatically created at the end of each compound. 
*/ val isInternal: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala index 4d7af0f271c9d..344d546dc9e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala @@ -150,4 +150,4 @@ class SqlScriptingIntegrationSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScript, expected) } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index de94519861d7c..0ee8f01ce6d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -78,4 +78,4 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite { assert(statements === Seq("one", "two", "three", "four", "five")) } -} \ No newline at end of file +} From dd20a7e546c452b4a1e59fdad5333b98404fe60d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 19 Jun 2024 14:00:05 +0200 Subject: [PATCH 04/99] Addressing comments vol1 --- .../scripting/SqlScriptingExecutionNode.scala | 7 +- .../scripting/SqlScriptingInterpreter.scala | 1 + .../SqlScriptingExecutionNodeSuite.scala | 86 ++++++++ .../SqlScriptingIntegrationSuite.scala | 153 -------------- .../SqlScriptingInterpreterSuite.scala | 186 +++++++++++++----- 5 files changed, 229 insertions(+), 204 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 2420a60983927..3f6566934339d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} */ sealed trait CompoundStatementExec extends Logging { /** - * Whether the statement originates from the SQL script or it is created during the interpretation. + * Whether the statement originates from the SQL script or is created during the interpretation. * Example: DropVariable statements are automatically created at the end of each compound. */ val isInternal: Boolean = false @@ -52,8 +52,9 @@ trait NonLeafStatementExec extends CompoundStatementExec with Iterator[CompoundS * Executable node for SingleStatement. * @param parsedPlan Logical plan of the parsed statement. * @param origin Origin descriptor for the statement. - * @param isInternal Whether the statement originates from the script - * or it is created during the interpretation. + * @param isInternal Whether the statement originates from the SQL script or it is created during + * the interpretation. Example: DropVariable statements are automatically created + * at the end of each compound. 
*/ class SingleStatementExec( var parsedPlan: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index af9d9f0444cd3..ef3e7fa362c38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -63,6 +63,7 @@ case class SqlScriptingInterpreter() extends ProceduralLanguageInterpreter { private def transformTreeIntoExecutable(node: CompoundPlanStatement): CompoundStatementExec = node match { case body: CompoundBody => + // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. val variables = body.collection.flatMap { case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) case _ => None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala new file mode 100644 index 0000000000000..17e00d07d496b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.catalyst.trees.Origin + +/** + * Unit tests for execution nodes from SqlScriptingExecutionNode.scala. + * Execution nodes are constructed manually and iterated through. + * It is then checked if the leaf statements have been iterated in the expected order. 
+ */ +class SqlScriptingExecutionNodeSuite extends SparkFunSuite { + // Helpers + case class TestLeafStatement(testVal: String) extends LeafStatementExec { + override def reset(): Unit = () + } + + case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) + extends CompoundNestedStatementIteratorExec(statements) + + case class TestBody(statements: Seq[CompoundStatementExec]) + extends CompoundBodyExec(statements) + + case class TestSparkStatementWithPlan(testVal: String) + extends SingleStatementExec( + parsedPlan = Project(Seq(Alias(Literal(testVal), "condition")()), OneRowRelation()), + Origin(startIndex = Some(0), stopIndex = Some(testVal.length)), + isInternal = false) + + // Tests + test("test body - single statement") { + val iter = TestNestedStatementIterator(Seq(TestLeafStatement("one"))) + val statements = iter.map { + case TestLeafStatement(v) => v + case _ => fail("Unexpected statement type") + }.toList + + assert(statements === List("one")) + } + + test("test body - no nesting") { + val iter = TestNestedStatementIterator( + Seq( + TestLeafStatement("one"), + TestLeafStatement("two"), + TestLeafStatement("three"))) + val statements = iter.map { + case TestLeafStatement(v) => v + case _ => fail("Unexpected statement type") + }.toList + + assert(statements === Seq("one", "two", "three")) + } + + test("test body - nesting") { + val iter = TestNestedStatementIterator( + Seq( + TestNestedStatementIterator(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), + TestLeafStatement("three"), + TestNestedStatementIterator(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) + val statements = iter.map { + case TestLeafStatement(v) => v + case _ => fail("Unexpected statement type") + }.toList + + assert(statements === Seq("one", "two", "three", "four", "five")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala deleted file mode 100644 index 344d546dc9e56..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.scripting - -import org.apache.spark.sql.{Dataset, QueryTest, Row} -import org.apache.spark.sql.catalyst.QueryPlanningTracker -import org.apache.spark.sql.test.SharedSparkSession - -class SqlScriptingIntegrationSuite extends QueryTest with SharedSparkSession { - // Helpers - private def verifySqlScriptResult( - sqlText: String, expected: Seq[Seq[Row]], printResult: Boolean = false): Unit = { - val interpreter = SqlScriptingInterpreter() - val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) - val executionPlan = interpreter.buildExecutionPlan(compoundBody) - val result = executionPlan.flatMap { - case statement: SingleStatementExec => - if (printResult) { - // scalastyle:off println - println("Executing: " + statement.getText(sqlText)) - // scalastyle:on println - } - - if (statement.consumed) { - None - } else { - Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) - } - case _ => None - }.toArray - - assert(result.length == expected.length) - result.zip(expected).foreach{ case (df, expectedAnswer) => checkAnswer(df, expectedAnswer)} - } - - // Tests - test("select 1") { - verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) - } - - test("multi statement - simple") { - withTable("t") { - val sqlScript = - """ - |BEGIN - |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; - |INSERT INTO t VALUES (1, 'a', 1.0); - |SELECT a, b FROM t WHERE a = 12; - |SELECT a FROM t; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // select with filter - Seq(Row(1)) - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("multi statement - count") { - withTable("t") { - val sqlScript = - """ - |BEGIN - |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; - |INSERT INTO t VALUES (1, 'a', 1.0); - |INSERT INTO t VALUES (1, 'a', 1.0); - |SELECT - | CASE WHEN COUNT(*) > 10 THEN true - | ELSE false - | END AS MoreThanTen - |FROM t; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert #1 - Seq.empty[Row], // insert #2 - Seq(Row(false)) - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("session vars - set and read") { - val sqlScript = - """ - |BEGIN - |DECLARE var = 1; - |SET VAR var = var + 1; - |SELECT var; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var - ) - verifySqlScriptResult(sqlScript, expected) - } - - test("session vars - set and read scoped") { - val sqlScript = - """ - |BEGIN - | BEGIN - | DECLARE var = 1; - | SELECT var; - | END; - | BEGIN - | DECLARE var = 2; - | SELECT var; - | END; - | BEGIN - | DECLARE var = 3; - | SET VAR var = var + 1; - | SELECT var; - | END; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row], // drop var - Seq.empty[Row], // declare var - Seq(Row(2)), // select - Seq.empty[Row], // drop var - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(4)), // select - Seq.empty[Row], // drop var - ) - verifySqlScriptResult(sqlScript, expected) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 0ee8f01ce6d6c..8b29644440cc8 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -17,65 +17,155 @@ package org.apache.spark.sql.scripting -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, Row} +import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.test.SharedSparkSession -class SqlScriptingInterpreterSuite extends SparkFunSuite { +/** + * SQL Scripting interpreter tests. + * Output from the parser is provided to the interpreter. + * Output from the interpreter (iterator over executable statements) is then checked - statements + * are executed and output DataFrames are compared with expected outputs. + */ +class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // Helpers - case class TestLeafStatement(testVal: String) extends LeafStatementExec { - override def reset(): Unit = () - } - - case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) - extends CompoundNestedStatementIteratorExec(statements) - - case class TestBody(statements: Seq[CompoundStatementExec]) - extends CompoundBodyExec(statements) + private def verifySqlScriptResult( + sqlText: String, expected: Seq[Seq[Row]]): Unit = { + val interpreter = SqlScriptingInterpreter() + val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) + val executionPlan = interpreter.buildExecutionPlan(compoundBody) + val result = executionPlan.flatMap { + case statement: SingleStatementExec => + if (statement.consumed) { + None + } else { + Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) + } + case _ => None + }.toArray - case class TestSparkStatementWithPlan(testVal: String) - extends SingleStatementExec( - parsedPlan = Project(Seq(Alias(Literal(testVal), "condition")()), OneRowRelation()), - Origin(startIndex = Some(0), stopIndex = Some(testVal.length)), - isInternal = false) + assert(result.length == expected.length) + result.zip(expected).foreach{ case (df, expectedAnswer) => checkAnswer(df, expectedAnswer)} + } // Tests - test("test body - single statement") { - val iter = TestNestedStatementIterator(Seq(TestLeafStatement("one"))) - val statements = iter.map { - case TestLeafStatement(v) => v - case _ => fail("Unexpected statement type") - }.toList + test("select 1") { + verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) + } - assert(statements === List("one")) + test("multi statement - simple") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT a, b FROM t WHERE a = 12; + |SELECT a FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // select with filter + Seq(Row(1)) + ) + verifySqlScriptResult(sqlScript, expected) + } } - test("test body - no nesting") { - val iter = TestNestedStatementIterator( - Seq( - TestLeafStatement("one"), - TestLeafStatement("two"), - TestLeafStatement("three"))) - val statements = iter.map { - case TestLeafStatement(v) => v - case _ => fail("Unexpected statement type") - }.toList + test("multi statement - count") { + withTable("t") { + val 
sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT + | CASE WHEN COUNT(*) > 10 THEN true + | ELSE false + | END AS MoreThanTen + |FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert #1 + Seq.empty[Row], // insert #2 + Seq(Row(false)) + ) + verifySqlScriptResult(sqlScript, expected) + } + } - assert(statements === Seq("one", "two", "three")) + test("session vars - set and read") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(2)), // select + Seq.empty[Row] // drop var + ) + verifySqlScriptResult(sqlScript, expected) } - test("test body - nesting") { - val iter = TestNestedStatementIterator( - Seq( - TestNestedStatementIterator(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), - TestLeafStatement("three"), - TestNestedStatementIterator(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) - val statements = iter.map { - case TestLeafStatement(v) => v - case _ => fail("Unexpected statement type") - }.toList + test("session vars - set and read scoped") { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE var = 1; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 2; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 3; + | SET VAR var = var + 1; + | SELECT var; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row], // drop var + Seq.empty[Row], // declare var + Seq(Row(2)), // select + Seq.empty[Row], // drop var + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(4)), // select + Seq.empty[Row], // drop var + ) + verifySqlScriptResult(sqlScript, expected) + } - assert(statements === Seq("one", "two", "three", "four", "five")) + test("session vars - var out of scope") { + val e = intercept[AnalysisException] { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE testVarName = 1; + | SELECT testVarName; + | END; + | SELECT testVarName; + |END + |""".stripMargin + verifySqlScriptResult(sqlScript, Seq.empty) + } + assert(e.getErrorClass === "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") + assert{e.getMessage.contains("testVarName")} } } From 0117228fe088e12a99d11e276fc4f7d7a3faeab4 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 21 Jun 2024 12:28:49 +0200 Subject: [PATCH 05/99] Addressing comments vol2 --- .../scripting/SqlScriptingExecutionNode.scala | 10 +++++++--- .../SqlScriptingInterpreterSuite.scala | 20 +++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 3f6566934339d..cd2352bc1fdd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.scripting +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -97,14 +98,16 @@ abstract class 
CompoundNestedStatementIteratorExec(collection: Seq[CompoundState case Some(body: NonLeafStatementExec) => body.hasNext case Some(_: LeafStatementExec) => true case None => false - case _ => throw new IllegalStateException("Unknown statement type") + case _ => throw SparkException.internalError( + "Unknown statement type encountered during SQL script interpretation.") } localIterator.hasNext || childHasNext } override def next(): CompoundStatementExec = { curr match { - case None => throw new IllegalStateException("No more elements") + case None => throw SparkException.internalError( + "No more elements to iterate through in the current SQL compound statement.") case Some(statement: LeafStatementExec) => if (localIterator.hasNext) curr = Some(localIterator.next()) else curr = None @@ -116,7 +119,8 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState curr = if (localIterator.hasNext) Some(localIterator.next()) else None next() } - case _ => throw new IllegalStateException("Unknown statement type") + case _ => throw SparkException.internalError( + "Unknown statement type encountered during SQL script interpretation.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8b29644440cc8..d2d93d3684e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -168,4 +168,24 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { assert(e.getErrorClass === "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") assert{e.getMessage.contains("testVarName")} } + + test("session vars - drop var statement") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |DROP TEMPORARY VARIABLE var; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(2)), // select + Seq.empty[Row], // drop var - explicit + Seq.empty[Row] // drop var - implicit + ) + verifySqlScriptResult(sqlScript, expected) + } } From bd2a90707d7179645f39c157a5857ab19f6da469 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 00:11:39 +0200 Subject: [PATCH 06/99] Addressing comments vol3 --- .../scripting/SqlScriptingExecutionNode.scala | 11 ++++------- .../sql/scripting/SqlScriptingInterpreter.scala | 17 +++++------------ 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index cd2352bc1fdd9..a210d36f34551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -74,11 +74,9 @@ class SingleStatementExec( override def reset(): Unit = consumed = false /** Get the SQL query text corresponding to this statement. 
*/ - def getText(sqlScriptText: String): String = { - if (origin.startIndex.isEmpty || origin.stopIndex.isEmpty) { - return null - } - sqlScriptText.substring(origin.startIndex.get, origin.stopIndex.get + 1) + def getText: String = { + assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) + origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) } } @@ -109,8 +107,7 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") case Some(statement: LeafStatementExec) => - if (localIterator.hasNext) curr = Some(localIterator.next()) - else curr = None + curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement case Some(body: NonLeafStatementExec) => if (body.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index ef3e7fa362c38..4fbbe73d85642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -23,24 +23,17 @@ import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable import org.apache.spark.sql.catalyst.trees.Origin /** - * Trait for SQL Scripting interpreters. + * SQL scripting interpreter - builds SQL script execution plan. */ -trait ProceduralLanguageInterpreter { +case class SqlScriptingInterpreter() { /** * Build execution plan and return statements that need to be executed, - * wrapped in the execution node. + * wrapped in the execution node. + * * @param compound CompoundBody for which to build the plan. * @return Iterator through collection of statements to be executed. */ - def buildExecutionPlan(compound: CompoundBody) : Iterator[CompoundStatementExec] -} - -/** - * Concrete implementation of the interpreter for SQL scripting. 
- */ -case class SqlScriptingInterpreter() extends ProceduralLanguageInterpreter { - /** @inheritdoc */ - override def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { + def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec] } From 746e04bf7f157c12065f05c409ae1acba7fe6a9d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 00:12:40 +0200 Subject: [PATCH 07/99] Improve getText in SingleStatement logical operator --- .../parser/SqlScriptingLogicalOperators.scala | 9 ++++----- .../catalyst/parser/SqlScriptingParserSuite.scala | 14 +++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 816ef82a3d8e6..b458aeaa9c439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -37,11 +37,10 @@ case class SingleStatement(parsedPlan: LogicalPlan) override val origin: Origin = CurrentOrigin.get - def getText(sqlScriptText: String): String = { - if (origin.startIndex.isEmpty || origin.stopIndex.isEmpty) { - return null - } - sqlScriptText.substring(origin.startIndex.get, origin.stopIndex.get + 1) + /** Get the SQL query text corresponding to this statement. */ + def getText: String = { + assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) + origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 657e4b2232ee9..aa72e409f528b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -29,7 +29,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[SingleStatement]) val sparkStatement = tree.collection.head.asInstanceOf[SingleStatement] - assert(sparkStatement.getText(sqlScriptText) == "SELECT 1;") + assert(sparkStatement.getText == "SELECT 1;") } test("single select without ;") { @@ -38,7 +38,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[SingleStatement]) val sparkStatement = tree.collection.head.asInstanceOf[SingleStatement] - assert(sparkStatement.getText(sqlScriptText) == "SELECT 1") + assert(sparkStatement.getText == "SELECT 1") } test("multi select without ; - should fail") { @@ -62,7 +62,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .zip(tree.collection) .foreach { case (expected, statement) => val sparkStatement = statement.asInstanceOf[SingleStatement] - val statementText = sparkStatement.getText(sqlScriptText) + val statementText = sparkStatement.getText assert(statementText == expected) } } @@ -124,7 +124,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .zip(tree.collection) .foreach { case 
(expected, statement) => val sparkStatement = statement.asInstanceOf[SingleStatement] - val statementText = sparkStatement.getText(sqlScriptText) + val statementText = sparkStatement.getText assert(statementText == expected) } } @@ -148,16 +148,16 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.head.isInstanceOf[CompoundBody]) val body1 = tree.collection.head.asInstanceOf[CompoundBody] assert(body1.collection.length == 1) - assert(body1.collection.head.asInstanceOf[SingleStatement].getText(sqlScriptText) + assert(body1.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") val body2 = tree.collection(1).asInstanceOf[CompoundBody] assert(body2.collection.length == 1) assert(body2.collection.head.isInstanceOf[CompoundBody]) val nestedBody = body2.collection.head.asInstanceOf[CompoundBody] - assert(nestedBody.collection.head.asInstanceOf[SingleStatement].getText(sqlScriptText) + assert(nestedBody.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 2") - assert(nestedBody.collection(1).asInstanceOf[SingleStatement].getText(sqlScriptText) + assert(nestedBody.collection(1).asInstanceOf[SingleStatement].getText == "SELECT 3") } From 3126baf9b316fbc0e2414cf5afa7cb518d0c21a9 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 01:05:56 +0200 Subject: [PATCH 08/99] Build break fix + minor styling --- .../scripting/SqlScriptingExecutionNode.scala | 25 +++++++++++-------- .../scripting/SqlScriptingInterpreter.scala | 19 +++++++++----- .../SqlScriptingExecutionNodeSuite.scala | 3 +-- .../SqlScriptingInterpreterSuite.scala | 17 ++++++------- 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index a210d36f34551..f585c9813dd85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} * Trait for all SQL scripting execution nodes used during interpretation phase. */ sealed trait CompoundStatementExec extends Logging { + /** * Whether the statement originates from the SQL script or is created during the interpretation. * Example: DropVariable statements are automatically created at the end of each compound. @@ -44,23 +45,25 @@ sealed trait CompoundStatementExec extends Logging { trait LeafStatementExec extends CompoundStatementExec /** - * Non-leaf node in the execution tree. - * It is an iterator over executable child nodes. + * Non-leaf node in the execution tree. It is an iterator over executable child nodes. */ trait NonLeafStatementExec extends CompoundStatementExec with Iterator[CompoundStatementExec] /** * Executable node for SingleStatement. - * @param parsedPlan Logical plan of the parsed statement. - * @param origin Origin descriptor for the statement. - * @param isInternal Whether the statement originates from the SQL script or it is created during - * the interpretation. Example: DropVariable statements are automatically created - * at the end of each compound. + * @param parsedPlan + * Logical plan of the parsed statement. + * @param origin + * Origin descriptor for the statement. + * @param isInternal + * Whether the statement originates from the SQL script or it is created during the + * interpretation. 
Example: DropVariable statements are automatically created at the end of each + * compound. */ class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, - override val isInternal : Boolean) + override val isInternal: Boolean) extends LeafStatementExec with WithOrigin { @@ -83,7 +86,8 @@ class SingleStatementExec( /** * Abstract class for all statements that contain nested statements. * Implements recursive iterator logic over all child execution nodes. - * @param collection Collection of child execution nodes. + * @param collection + * Collection of child execution nodes. */ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) extends NonLeafStatementExec { @@ -130,7 +134,8 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState /** * Executable node for CompoundBody. - * @param statements Executable nodes for nested statements within the CompoundBody. + * @param statements + * Executable nodes for nested statements within the CompoundBody. */ class CompoundBodyExec(statements: Seq[CompoundStatementExec]) extends CompoundNestedStatementIteratorExec(statements) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 4fbbe73d85642..755a46428d554 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -26,12 +26,15 @@ import org.apache.spark.sql.catalyst.trees.Origin * SQL scripting interpreter - builds SQL script execution plan. */ case class SqlScriptingInterpreter() { + /** * Build execution plan and return statements that need to be executed, * wrapped in the execution node. * - * @param compound CompoundBody for which to build the plan. - * @return Iterator through collection of statements to be executed. + * @param compound + * CompoundBody for which to build the plan. + * @return + * Iterator through collection of statements to be executed. */ def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec] @@ -39,8 +42,10 @@ case class SqlScriptingInterpreter() { /** * Fetch the name of the Create Variable plan. - * @param plan Plan to fetch the name from. - * @return Name of the variable. + * @param plan + * Plan to fetch the name from. + * @return + * Name of the variable. */ private def getDeclareVarNameFromPlan(plan: LogicalPlan): Option[UnresolvedIdentifier] = plan match { @@ -50,8 +55,10 @@ case class SqlScriptingInterpreter() { /** * Transform the parsed tree to the executable node. - * @param node Root node of the parsed tree. - * @return Executable statement. + * @param node + * Root node of the parsed tree. + * @return + * Executable statement. 
*/ private def transformTreeIntoExecutable(node: CompoundPlanStatement): CompoundStatementExec = node match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 17e00d07d496b..e6a3f1e2bbccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -36,8 +36,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) extends CompoundNestedStatementIteratorExec(statements) - case class TestBody(statements: Seq[CompoundStatementExec]) - extends CompoundBodyExec(statements) + case class TestBody(statements: Seq[CompoundStatementExec]) extends CompoundBodyExec(statements) case class TestSparkStatementWithPlan(testVal: String) extends SingleStatementExec( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index d2d93d3684e50..70384ae5e2f82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.test.SharedSparkSession * SQL Scripting interpreter tests. * Output from the parser is provided to the interpreter. * Output from the interpreter (iterator over executable statements) is then checked - statements - * are executed and output DataFrames are compared with expected outputs. + * are executed and output DataFrames are compared with expected outputs. 
*/ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // Helpers - private def verifySqlScriptResult( - sqlText: String, expected: Seq[Seq[Row]]): Unit = { + private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { val interpreter = SqlScriptingInterpreter() val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) val executionPlan = interpreter.buildExecutionPlan(compoundBody) @@ -45,7 +44,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { }.toArray assert(result.length == expected.length) - result.zip(expected).foreach{ case (df, expectedAnswer) => checkAnswer(df, expectedAnswer)} + result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } } // Tests @@ -68,8 +67,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq.empty[Row], // select with filter - Seq(Row(1)) - ) + Seq(Row(1))) verifySqlScriptResult(sqlScript, expected) } } @@ -93,8 +91,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert #1 Seq.empty[Row], // insert #2 - Seq(Row(false)) - ) + Seq(Row(false))) verifySqlScriptResult(sqlScript, expected) } } @@ -146,7 +143,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // declare var Seq.empty[Row], // set var Seq(Row(4)), // select - Seq.empty[Row], // drop var + Seq.empty[Row] // drop var ) verifySqlScriptResult(sqlScript, expected) } @@ -166,7 +163,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, Seq.empty) } assert(e.getErrorClass === "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") - assert{e.getMessage.contains("testVarName")} + assert(e.getMessage.contains("testVarName")) } test("session vars - drop var statement") { From df563bdcf4478911d1882f471a867d6b4244fc21 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 01:09:26 +0200 Subject: [PATCH 09/99] Add minor missing comments --- .../spark/sql/scripting/SqlScriptingInterpreterSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 70384ae5e2f82..b3e7cda71ade4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -67,7 +67,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq.empty[Row], // select with filter - Seq(Row(1))) + Seq(Row(1)) // select + ) verifySqlScriptResult(sqlScript, expected) } } @@ -91,7 +92,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert #1 Seq.empty[Row], // insert #2 - Seq(Row(false))) + Seq(Row(false)) // select + ) verifySqlScriptResult(sqlScript, expected) } } From 6d402fc4fa6b38f9911bfbcc80dbdedcf54964a1 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 08:04:55 +0200 Subject: [PATCH 10/99] Exclude Scripting from Connect client compatibility check --- 
.../sql/connect/client/CheckConnectJvmClientCompatibility.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 7bf7673a7a121..aec812e15171d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -164,6 +164,7 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.streaming.ui.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.test.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.util.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.scripting.*"), // Skip private[sql] constructors ProblemFilters.exclude[Problem]("org.apache.spark.sql.*.this"), From 5a2055f47b68391ff6a12e8ee0ba2ef6842c345c Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 08:47:31 +0200 Subject: [PATCH 11/99] Minor doc comment changes --- .../sql/catalyst/parser/SqlScriptingLogicalOperators.scala | 6 +++++- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index b458aeaa9c439..5b2b6ab95b459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -37,7 +37,11 @@ case class SingleStatement(parsedPlan: LogicalPlan) override val origin: Origin = CurrentOrigin.get - /** Get the SQL query text corresponding to this statement. */ + /** + * Get the SQL query text corresponding to this statement. + * @return + * SQL query text. + */ def getText: String = { assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index f585c9813dd85..e8a5d6af8e41a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -76,7 +76,11 @@ class SingleStatementExec( /** @inheritdoc */ override def reset(): Unit = consumed = false - /** Get the SQL query text corresponding to this statement. */ + /** + * Get the SQL query text corresponding to this statement. + * @return + * SQL query text. 
+ */ def getText: String = { assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) From fcbc98d4a829ab28814f817a74f4f7e86f73a74d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 27 Jun 2024 11:15:39 +0200 Subject: [PATCH 12/99] Remove inheritdoc comment --- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index e8a5d6af8e41a..f294b6dab15a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -73,7 +73,9 @@ class SingleStatementExec( */ var consumed = false - /** @inheritdoc */ + /** + * Reset execution of the current node. + */ override def reset(): Unit = consumed = false /** From a98259aa2191bcbf3029c4fcfc129ec4bb5e7a9e Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Mon, 20 May 2024 12:32:10 +0200 Subject: [PATCH 13/99] Initial changes for SQL Batch Lang Parser --- .../main/resources/error/error-states.json | 6 + .../parser/BatchLangLogicalOperators.scala | 35 ++++ .../spark/sql/errors/SqlBatchLangErrors.scala | 40 ++++ .../exceptions/SqlBatchLangException.scala | 41 ++++ .../catalyst/parser/BatchParserSuite.scala | 187 ++++++++++++++++++ 5 files changed, 309 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index dd87d6bda5f22..d84da816123c9 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4613,6 +4613,12 @@ "standard": "N", "usedBy": ["Spark"] }, + "42K0L": { + "description": "Variable declaration not allowed.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala new file mode 100644 index 0000000000000..18a7fa4a039ff --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +sealed trait BatchPlanStatement + +// Statement that is supposed to be executed against Spark. +// This can also be a Spark expression that is wrapped in a statement. +case class SparkStatementWithPlan( + parsedPlan: LogicalPlan, + sourceStart: Int, + sourceEnd: Int) + extends BatchPlanStatement { + + def getText(batch: String): String = batch.substring(sourceStart, sourceEnd) +} + +case class BatchBody(collection: List[BatchPlanStatement]) extends BatchPlanStatement \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala new file mode 100644 index 0000000000000..1e5181ef3ce36 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.errors + +import org.apache.spark.sql.exceptions.SqlBatchLangException + +/** + * Object for grouping error messages thrown during parsing/interpreting phase + * of the SQL Batch Language interpreter. + */ +private[sql] object SqlBatchLangErrors extends QueryErrorsBase { + + def variableDeclarationNotAllowed(varName: String): Throwable = { + new SqlBatchLangException( + errorClass = "SQL_BATCH_LANG_INVALID_VARIABLE_DECLARATION.NOT_ALLOWED", + messageParameters = Map("varDeclExpr" -> varName)) + } + + def variableDeclarationOnlyAtBeginning(varName: String): Throwable = { + new SqlBatchLangException( + errorClass = "SQL_BATCH_LANG_INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", + messageParameters = Map("varDeclExpr" -> varName)) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala new file mode 100644 index 0000000000000..883256d3b7bcc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.exceptions + +import org.apache.spark.{SparkThrowable, SparkThrowableHelper} + +class SqlBatchLangException protected ( + val message: String, + val errorClass: Option[String] = None, + val messageParameters: Map[String, String] = Map.empty, + val cause: Option[Throwable] = None, + val sourceLine: Option[Int] = None, + val sourceStart: Option[Int] = None, + val sourceEnd: Option[Int] = None) + extends Exception(message, cause.orNull) with SparkThrowable with Serializable { + + def this( + errorClass: String, + messageParameters: Map[String, String]) = + this( + message = SparkThrowableHelper.getMessage(errorClass, messageParameters), + errorClass = Some(errorClass), + messageParameters = messageParameters) + + override def getErrorClass: String = errorClass.orNull +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala new file mode 100644 index 0000000000000..e2a2969a2a52c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.exceptions.SqlBatchLangException + +class BatchParserSuite extends SparkFunSuite with SQLHelper { + import CatalystSqlParser._ + + test("single select") { + val batch = "SELECT 1;" + val tree = parseBatch(batch) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[SparkStatementWithPlan]) + val sparkStatement = tree.collection.head.asInstanceOf[SparkStatementWithPlan] + assert(sparkStatement.getText(batch) == "SELECT 1;") + } + + test("single select without ;") { + val batch = "SELECT 1" + val tree = parseBatch(batch) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[SparkStatementWithPlan]) + val sparkStatement = tree.collection.head.asInstanceOf[SparkStatementWithPlan] + assert(sparkStatement.getText(batch) == "SELECT 1") + } + + test("multi select without ; - should fail") { + val batch = "SELECT 1 SELECT 1" + intercept[ParseException] { + parseBatch(batch) + } + } + + test("multi select") { + val batch = "BEGIN SELECT 1;SELECT 2; END" + val tree = parseBatch(batch) + assert(tree.collection.length == 2) + assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + + batch.split(";") + .map(_.replace("\n", "").replace("BEGIN", "").replace("END", "").trim) + .zip(tree.collection) + .foreach { case (expected, statement) => + val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] + val statementText = sparkStatement.getText(batch) + assert(statementText == expected) + } + } + + test("multi statement") { + val batch = + """ + |BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + |END""".stripMargin + val tree = parseBatch(batch) + assert(tree.collection.length == 5) + assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + batch.split(";") + .map(_.replace("\n", "").replace("BEGIN", "").replace("END", "").trim) + .zip(tree.collection) + .foreach { case (expected, statement) => + val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] + val statementText = sparkStatement.getText(batch) + assert(statementText == expected) + } + } + + test("multi statement without ; at the end") { + val batch = + """ + |BEGIN + |SELECT 1; + |SELECT 2; + |INSERT INTO A VALUES (a, b, 3); + |SELECT a, b, c FROM T; + |SELECT * FROM T + |END""".stripMargin + val tree = parseBatch(batch) + assert(tree.collection.length == 5) + assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + batch.split(";") + .map(_.replace("\n", "").replace("BEGIN", "").replace("END", "").trim) + .zip(tree.collection) + .foreach { case (expected, statement) => + val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] + val statementText = sparkStatement.getText(batch) + assert(statementText == expected) + } + } + + test("nested begin end") { + val batch = + """ + |BEGIN + | BEGIN + | SELECT 1; + | END; + | BEGIN + | BEGIN + | SELECT 2; + | SELECT 3; + | END; + | END; + |END""".stripMargin + val tree = parseBatch(batch) + assert(tree.collection.length == 2) + assert(tree.collection.head.isInstanceOf[BatchBody]) + val body1 = tree.collection.head.asInstanceOf[BatchBody] + assert(body1.collection.length == 1) + assert(body1.collection.head.asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 1") + + val body2 = tree.collection(1).asInstanceOf[BatchBody] + 
assert(body2.collection.length == 1) + assert(body2.collection.head.isInstanceOf[BatchBody]) + val nestedBody = body2.collection.head.asInstanceOf[BatchBody] + assert( + nestedBody.collection.head.asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 2") + assert( + nestedBody.collection(1).asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 3") + } + + test("variable declare and set") { + val batch = + """ + |BEGIN + |DECLARE totalInsertCount = 0; + |SET VAR totalInsertCount = totalInsertCount + 1; + |END""".stripMargin + val tree = parseBatch(batch) + + assert(tree.collection.length == 2) + assert(tree.collection.head.isInstanceOf[SparkStatementWithPlan]) + assert(tree.collection(1).isInstanceOf[SparkStatementWithPlan]) + } + + test ("declare in compound top") { + val batch = + """ + |BEGIN + |DECLARE totalInsertCount = 0; + |SET VAR totalInsertCount = totalInsertCount + 1; + |BEGIN + | DECLARE totalInsertCount2 = 0; + | SET VAR totalInsertCount2 = totalInsertCount2 + 1; + |END + |END""".stripMargin + val _ = parseBatch(batch) + } + + test("declare after compound top") { + val batch = + """ + |BEGIN + |DECLARE totalInsertCount = 0; + |SET VAR totalInsertCount = totalInsertCount + 1; + |DECLARE totalInsertCount2 = 0; + |END""".stripMargin + val e = intercept[SqlBatchLangException] { + parseBatch(batch) + } + assert(e.getErrorClass === "SQL_BATCH_LANG_INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING") + assert(e.getMessage.contains("DECLARE totalInsertCount2 = 0;")) + } +} From 14ec8e0a5eb9188ef930c20149e59cbd4d7b4cda Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Mon, 20 May 2024 12:40:28 +0200 Subject: [PATCH 14/99] Add missing empty line --- .../spark/sql/catalyst/parser/BatchLangLogicalOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala index 18a7fa4a039ff..105e2e844e6d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala @@ -32,4 +32,4 @@ case class SparkStatementWithPlan( def getText(batch: String): String = batch.substring(sourceStart, sourceEnd) } -case class BatchBody(collection: List[BatchPlanStatement]) extends BatchPlanStatement \ No newline at end of file +case class BatchBody(collection: List[BatchPlanStatement]) extends BatchPlanStatement From b4638a8f9f129cd59911b6a1ab4bc736d6b8ed71 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 21 May 2024 00:02:04 +0200 Subject: [PATCH 15/99] Addressing comments --- .../main/resources/error/error-states.json | 6 -- .../spark/sql/errors/SqlBatchLangErrors.scala | 40 ------------- .../exceptions/SqlBatchLangException.scala | 41 ------------- .../catalyst/parser/BatchParserSuite.scala | 60 +++++-------------- 4 files changed, 14 insertions(+), 133 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index d84da816123c9..dd87d6bda5f22 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ 
b/common/utils/src/main/resources/error/error-states.json @@ -4613,12 +4613,6 @@ "standard": "N", "usedBy": ["Spark"] }, - "42K0L": { - "description": "Variable declaration not allowed.", - "origin": "Spark", - "standard": "N", - "usedBy": ["Spark"] - }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala deleted file mode 100644 index 1e5181ef3ce36..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlBatchLangErrors.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.errors - -import org.apache.spark.sql.exceptions.SqlBatchLangException - -/** - * Object for grouping error messages thrown during parsing/interpreting phase - * of the SQL Batch Language interpreter. - */ -private[sql] object SqlBatchLangErrors extends QueryErrorsBase { - - def variableDeclarationNotAllowed(varName: String): Throwable = { - new SqlBatchLangException( - errorClass = "SQL_BATCH_LANG_INVALID_VARIABLE_DECLARATION.NOT_ALLOWED", - messageParameters = Map("varDeclExpr" -> varName)) - } - - def variableDeclarationOnlyAtBeginning(varName: String): Throwable = { - new SqlBatchLangException( - errorClass = "SQL_BATCH_LANG_INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", - messageParameters = Map("varDeclExpr" -> varName)) - } - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala deleted file mode 100644 index 883256d3b7bcc..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlBatchLangException.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.exceptions - -import org.apache.spark.{SparkThrowable, SparkThrowableHelper} - -class SqlBatchLangException protected ( - val message: String, - val errorClass: Option[String] = None, - val messageParameters: Map[String, String] = Map.empty, - val cause: Option[Throwable] = None, - val sourceLine: Option[Int] = None, - val sourceStart: Option[Int] = None, - val sourceEnd: Option[Int] = None) - extends Exception(message, cause.orNull) with SparkThrowable with Serializable { - - def this( - errorClass: String, - messageParameters: Map[String, String]) = - this( - message = SparkThrowableHelper.getMessage(errorClass, messageParameters), - errorClass = Some(errorClass), - messageParameters = messageParameters) - - override def getErrorClass: String = errorClass.orNull -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala index e2a2969a2a52c..e3eefe90cfd3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.exceptions.SqlBatchLangException class BatchParserSuite extends SparkFunSuite with SQLHelper { import CatalystSqlParser._ @@ -44,9 +43,12 @@ class BatchParserSuite extends SparkFunSuite with SQLHelper { test("multi select without ; - should fail") { val batch = "SELECT 1 SELECT 1" - intercept[ParseException] { + val e = intercept[ParseException] { parseBatch(batch) } + assert(e.getErrorClass === "PARSE_SYNTAX_ERROR") + assert(e.getMessage.contains("Syntax error")) + assert(e.getMessage.contains("SELECT 1 SELECT 1")) } test("multi select") { @@ -56,7 +58,7 @@ class BatchParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) batch.split(";") - .map(_.replace("\n", "").replace("BEGIN", "").replace("END", "").trim) + .map(cleanupStatementString) .zip(tree.collection) .foreach { case (expected, statement) => val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] @@ -79,7 +81,7 @@ class BatchParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.length == 5) assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) batch.split(";") - .map(_.replace("\n", "").replace("BEGIN", "").replace("END", "").trim) + .map(cleanupStatementString) .zip(tree.collection) .foreach { case (expected, statement) => val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] @@ -102,7 +104,7 @@ class BatchParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.length == 5) assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) batch.split(";") - .map(_.replace("\n", "").replace("BEGIN", "").replace("END", "").trim) + .map(cleanupStatementString) .zip(tree.collection) .foreach { case (expected, statement) => val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] @@ -142,46 +144,12 @@ class BatchParserSuite extends SparkFunSuite with SQLHelper { nestedBody.collection(1).asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 3") } - test("variable declare and set") { - val batch = - """ - |BEGIN - |DECLARE totalInsertCount = 0; - |SET VAR totalInsertCount 
= totalInsertCount + 1; - |END""".stripMargin - val tree = parseBatch(batch) - - assert(tree.collection.length == 2) - assert(tree.collection.head.isInstanceOf[SparkStatementWithPlan]) - assert(tree.collection(1).isInstanceOf[SparkStatementWithPlan]) - } - - test ("declare in compound top") { - val batch = - """ - |BEGIN - |DECLARE totalInsertCount = 0; - |SET VAR totalInsertCount = totalInsertCount + 1; - |BEGIN - | DECLARE totalInsertCount2 = 0; - | SET VAR totalInsertCount2 = totalInsertCount2 + 1; - |END - |END""".stripMargin - val _ = parseBatch(batch) - } - - test("declare after compound top") { - val batch = - """ - |BEGIN - |DECLARE totalInsertCount = 0; - |SET VAR totalInsertCount = totalInsertCount + 1; - |DECLARE totalInsertCount2 = 0; - |END""".stripMargin - val e = intercept[SqlBatchLangException] { - parseBatch(batch) - } - assert(e.getErrorClass === "SQL_BATCH_LANG_INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING") - assert(e.getMessage.contains("DECLARE totalInsertCount2 = 0;")) + // Helper methods + def cleanupStatementString(statementStr: String): String = { + statementStr + .replace("\n", "") + .replace("BEGIN", "") + .replace("END", "") + .trim } } From 57c446e31d09b77affbd600328c073a24fe48ba7 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 21 May 2024 13:50:10 +0200 Subject: [PATCH 16/99] Further improvements --- .../parser/BatchLangLogicalOperators.scala | 35 ---- .../catalyst/parser/BatchParserSuite.scala | 155 ------------------ 2 files changed, 190 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala deleted file mode 100644 index 105e2e844e6d0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/BatchLangLogicalOperators.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -sealed trait BatchPlanStatement - -// Statement that is supposed to be executed against Spark. -// This can also be a Spark expression that is wrapped in a statement. 
-case class SparkStatementWithPlan( - parsedPlan: LogicalPlan, - sourceStart: Int, - sourceEnd: Int) - extends BatchPlanStatement { - - def getText(batch: String): String = batch.substring(sourceStart, sourceEnd) -} - -case class BatchBody(collection: List[BatchPlanStatement]) extends BatchPlanStatement diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala deleted file mode 100644 index e3eefe90cfd3e..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/BatchParserSuite.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.SQLHelper - -class BatchParserSuite extends SparkFunSuite with SQLHelper { - import CatalystSqlParser._ - - test("single select") { - val batch = "SELECT 1;" - val tree = parseBatch(batch) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[SparkStatementWithPlan]) - val sparkStatement = tree.collection.head.asInstanceOf[SparkStatementWithPlan] - assert(sparkStatement.getText(batch) == "SELECT 1;") - } - - test("single select without ;") { - val batch = "SELECT 1" - val tree = parseBatch(batch) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[SparkStatementWithPlan]) - val sparkStatement = tree.collection.head.asInstanceOf[SparkStatementWithPlan] - assert(sparkStatement.getText(batch) == "SELECT 1") - } - - test("multi select without ; - should fail") { - val batch = "SELECT 1 SELECT 1" - val e = intercept[ParseException] { - parseBatch(batch) - } - assert(e.getErrorClass === "PARSE_SYNTAX_ERROR") - assert(e.getMessage.contains("Syntax error")) - assert(e.getMessage.contains("SELECT 1 SELECT 1")) - } - - test("multi select") { - val batch = "BEGIN SELECT 1;SELECT 2; END" - val tree = parseBatch(batch) - assert(tree.collection.length == 2) - assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) - - batch.split(";") - .map(cleanupStatementString) - .zip(tree.collection) - .foreach { case (expected, statement) => - val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] - val statementText = sparkStatement.getText(batch) - assert(statementText == expected) - } - } - - test("multi statement") { - val batch = - """ - |BEGIN - | SELECT 1; - | SELECT 2; - | INSERT INTO A VALUES (a, b, 3); - | SELECT a, b, c FROM T; - | SELECT * FROM T; - |END""".stripMargin - val tree = parseBatch(batch) - assert(tree.collection.length == 5) - assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) - 
batch.split(";") - .map(cleanupStatementString) - .zip(tree.collection) - .foreach { case (expected, statement) => - val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] - val statementText = sparkStatement.getText(batch) - assert(statementText == expected) - } - } - - test("multi statement without ; at the end") { - val batch = - """ - |BEGIN - |SELECT 1; - |SELECT 2; - |INSERT INTO A VALUES (a, b, 3); - |SELECT a, b, c FROM T; - |SELECT * FROM T - |END""".stripMargin - val tree = parseBatch(batch) - assert(tree.collection.length == 5) - assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) - batch.split(";") - .map(cleanupStatementString) - .zip(tree.collection) - .foreach { case (expected, statement) => - val sparkStatement = statement.asInstanceOf[SparkStatementWithPlan] - val statementText = sparkStatement.getText(batch) - assert(statementText == expected) - } - } - - test("nested begin end") { - val batch = - """ - |BEGIN - | BEGIN - | SELECT 1; - | END; - | BEGIN - | BEGIN - | SELECT 2; - | SELECT 3; - | END; - | END; - |END""".stripMargin - val tree = parseBatch(batch) - assert(tree.collection.length == 2) - assert(tree.collection.head.isInstanceOf[BatchBody]) - val body1 = tree.collection.head.asInstanceOf[BatchBody] - assert(body1.collection.length == 1) - assert(body1.collection.head.asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 1") - - val body2 = tree.collection(1).asInstanceOf[BatchBody] - assert(body2.collection.length == 1) - assert(body2.collection.head.isInstanceOf[BatchBody]) - val nestedBody = body2.collection.head.asInstanceOf[BatchBody] - assert( - nestedBody.collection.head.asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 2") - assert( - nestedBody.collection(1).asInstanceOf[SparkStatementWithPlan].getText(batch) == "SELECT 3") - } - - // Helper methods - def cleanupStatementString(statementStr: String): String = { - statementStr - .replace("\n", "") - .replace("BEGIN", "") - .replace("END", "") - .trim - } -} From d6630249a27167796a9de399dce9861f4940c7c9 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 29 May 2024 16:45:29 +0200 Subject: [PATCH 17/99] Temp changes for interpreter --- .../SqlScriptingIntegrationSuite.scala | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala new file mode 100644 index 0000000000000..fa5b47def9740 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.sql.{Dataset, QueryTest, Row} +import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.test.SharedSparkSession + +class SqlScriptingIntegrationSuite extends QueryTest with SharedSparkSession { + // Helpers + private def verifySqlScriptResult( + sqlText: String, expected: Seq[Seq[Row]], printResult: Boolean = false): Unit = { + val interpreter = SqlScriptingInterpreter() + val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) + val executionPlan = interpreter.buildExecutionPlan(compoundBody, DataFrameEvaluator(spark)) + val result = executionPlan.flatMap { + case statement: SparkStatementWithPlanExec => + if (printResult) { + // scalastyle:off println + println("Executing: " + statement.getText(sqlText)) + // scalastyle:on println + } + + if (statement.consumed) { + None + } else { + Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) + } + case _ => None + }.toArray + + assert(result.length == expected.length) + result.zip(expected).foreach{ case (df, expectedAnswer) => checkAnswer(df, expectedAnswer)} + } + + // Tests + test("select 1") { + verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) + } + + test("multi statement - simple") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT a, b FROM t WHERE a = 12; + |SELECT a FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // select with filter + Seq(Row(1)) + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("multi statement - count") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT + | CASE WHEN COUNT(*) > 10 THEN true + | ELSE false + | END AS MoreThanTen + |FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert #1 + Seq.empty[Row], // insert #2 + Seq(Row(false)) + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("session vars - set and read") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(2)), // select + Seq.empty[Row] // drop var + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - set and read scoped") { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE var = 1; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 2; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 3; + | SET VAR var = var + 1; + | SELECT var; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row], // drop var + Seq.empty[Row], // declare var + Seq(Row(2)), // select + Seq.empty[Row], // drop var + Seq.empty[Row], // declare var + Seq.empty[Row], // set var + Seq(Row(4)), // select + Seq.empty[Row], // drop var + ) + verifySqlScriptResult(sqlScript, expected) + } +} \ No newline at end of file From 3989e12fb79f0c0fc11f63275d04482749650e36 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 28 Jun 2024 
14:41:07 +0200 Subject: [PATCH 18/99] Add tests --- .../sql/catalyst/parser/SqlBaseParser.g4 | 14 +++++- .../sql/catalyst/parser/AstBuilder.scala | 24 ++++++++-- .../parser/SqlScriptingLogicalOperators.scala | 4 +- .../parser/SqlScriptingParserSuite.scala | 47 +++++++++++++++++++ 4 files changed, 82 insertions(+), 7 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index ff863565910da..558812409fc55 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -52,7 +52,7 @@ singleCompoundStatement ; beginEndCompoundBlock - : BEGIN compoundBody END + : beginLabel? BEGIN compoundBody END endLabel? ; compoundBody @@ -68,6 +68,18 @@ singleStatement : statement SEMICOLON* EOF ; +label + : multipartIdentifier + ; + +beginLabel + : label COLON + ; + +endLabel + : label + ; + singleExpression : namedExpression EOF ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bca2c87253946..4bcead03a9530 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -121,7 +121,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { visit(s).asInstanceOf[CompoundBody] }.getOrElse { val logicalPlan = visitSingleStatement(ctx.singleStatement()) - CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan))) + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), java.util.UUID.randomUUID.toString) } } @@ -129,20 +129,34 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundBody] } - private def visitCompoundBodyImpl(ctx: CompoundBodyContext): CompoundBody = { + private def visitCompoundBodyImpl(ctx: CompoundBodyContext, label: String): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() ctx.compoundStatements.forEach(compoundStatement => { buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement] }) - CompoundBody(buff.toSeq) + CompoundBody(buff.toSeq, label) } override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = { - visitCompoundBodyImpl(ctx.compoundBody()) + val beginLabelCtx = Option(ctx.beginLabel()) + val endLabelCtx = Option(ctx.endLabel()) + + (beginLabelCtx, endLabelCtx) match { + case (Some(bl: BeginLabelContext), Some(el: EndLabelContext)) + if bl.label().getText.nonEmpty && bl.label().getText != el.label().getText => + throw SparkException.internalError("Both labels should be same.") + case (None, Some(_)) => + throw SparkException.internalError("End label can't exist without begin label.") + case _ => + } + + val labelText = beginLabelCtx. 
+ map(_.label().getText).getOrElse(java.util.UUID.randomUUID.toString) + visitCompoundBodyImpl(ctx.compoundBody(), labelText) } override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = { - visitCompoundBodyImpl(ctx) + visitCompoundBodyImpl(ctx, "") } override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 5b2b6ab95b459..adba3dc42115d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -52,4 +52,6 @@ case class SingleStatement(parsedPlan: LogicalPlan) * Logical operator for a compound body. Contains all statements within the compound body. * @param collection Collection of statements within the compound body. */ -case class CompoundBody(collection: Seq[CompoundPlanStatement]) extends CompoundPlanStatement +case class CompoundBody( + collection: Seq[CompoundPlanStatement], + label: String) extends CompoundPlanStatement diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index aa72e409f528b..d3d4e106bb19b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -161,6 +161,53 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { == "SELECT 3") } + test("compound: beginLabel") { + val batch = + """ + |lbl: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + |END""".stripMargin + val tree = parseScript(batch) + assert(tree.collection.length == 5) + assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + assert(tree.label.equals("lbl")) + } + + test("compound: beginLabel + endlLabel") { + val batch = + """ + |lbl: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + |END lbl""".stripMargin + val tree = parseScript(batch) + assert(tree.collection.length == 5) + assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + assert(tree.label.equals("lbl")) + } + + test("compound: beginLabel + endlLabel with different values") { + val batch = + """ + |lbl_begin: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + |END lbl_end""".stripMargin + val tree = parseScript(batch) + assert(tree.collection.length == 5) + assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + } + // Helper methods def cleanupStatementString(statementStr: String): String = { statementStr From 623cd2b94edf9bea8f725b92a86c7978d509056e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 28 Jun 2024 14:41:23 +0200 Subject: [PATCH 19/99] Add error tests --- .../parser/SqlScriptingParserSuite.scala | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index d3d4e106bb19b..1faa054a5b670 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.plans.SQLHelper class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { @@ -203,9 +203,28 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END lbl_end""".stripMargin - val tree = parseScript(batch) - assert(tree.collection.length == 5) - assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + val e = intercept[SparkException] { + parseScript(batch) + } + assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getMessage.contains("Both labels should be same.")) + } + + test("compound: endlLabel") { + val batch = + """ + |BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + |END lbl""".stripMargin + val e = intercept[SparkException] { + parseScript(batch) + } + assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getMessage.contains("End label can't exist without begin label.")) } // Helper methods From 214c685eb13df57b6b61f89a3eb26d60ac2c0ff5 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 28 Jun 2024 14:57:48 +0200 Subject: [PATCH 20/99] Add default value for CompoundBody constructor --- .../apache/spark/sql/scripting/SqlScriptingExecutionNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index f294b6dab15a5..7250de5726482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -143,5 +143,5 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * @param statements * Executable nodes for nested statements within the CompoundBody. 
*/ -class CompoundBodyExec(statements: Seq[CompoundStatementExec]) +class CompoundBodyExec(statements: Seq[CompoundStatementExec], label: String = "") extends CompoundNestedStatementIteratorExec(statements) From e888bab137bab8257191e5f4d2896f861dba8237 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Jul 2024 13:05:45 +0200 Subject: [PATCH 21/99] Fix scalastyle pt1 --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4bcead03a9530..001edcb7d6b61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -121,7 +121,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { visit(s).asInstanceOf[CompoundBody] }.getOrElse { val logicalPlan = visitSingleStatement(ctx.singleStatement()) - CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), java.util.UUID.randomUUID.toString) + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), + java.util.UUID.randomUUID.toString) } } From b417ddf37430672dd93370f34ef75db2aaba28b6 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Jul 2024 13:15:43 +0200 Subject: [PATCH 22/99] Fix SqlScriptingParserSuite --- .../spark/sql/catalyst/parser/SqlScriptingParserSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 1faa054a5b670..375303782a114 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -173,7 +173,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { |END""".stripMargin val tree = parseScript(batch) assert(tree.collection.length == 5) - assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + assert(tree.collection.forall(_.isInstanceOf[SingleStatement])) assert(tree.label.equals("lbl")) } @@ -189,7 +189,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { |END lbl""".stripMargin val tree = parseScript(batch) assert(tree.collection.length == 5) - assert(tree.collection.forall(_.isInstanceOf[SparkStatementWithPlan])) + assert(tree.collection.forall(_.isInstanceOf[SingleStatement])) assert(tree.label.equals("lbl")) } From 55e34d70c02dbb08ce64462775e0eb1e640b0ffd Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Jul 2024 13:55:12 +0200 Subject: [PATCH 23/99] Remove IntegrationSuite --- .../SqlScriptingIntegrationSuite.scala | 153 ------------------ 1 file changed, 153 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala deleted file mode 100644 index fa5b47def9740..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingIntegrationSuite.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.scripting - -import org.apache.spark.sql.{Dataset, QueryTest, Row} -import org.apache.spark.sql.catalyst.QueryPlanningTracker -import org.apache.spark.sql.test.SharedSparkSession - -class SqlScriptingIntegrationSuite extends QueryTest with SharedSparkSession { - // Helpers - private def verifySqlScriptResult( - sqlText: String, expected: Seq[Seq[Row]], printResult: Boolean = false): Unit = { - val interpreter = SqlScriptingInterpreter() - val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) - val executionPlan = interpreter.buildExecutionPlan(compoundBody, DataFrameEvaluator(spark)) - val result = executionPlan.flatMap { - case statement: SparkStatementWithPlanExec => - if (printResult) { - // scalastyle:off println - println("Executing: " + statement.getText(sqlText)) - // scalastyle:on println - } - - if (statement.consumed) { - None - } else { - Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) - } - case _ => None - }.toArray - - assert(result.length == expected.length) - result.zip(expected).foreach{ case (df, expectedAnswer) => checkAnswer(df, expectedAnswer)} - } - - // Tests - test("select 1") { - verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) - } - - test("multi statement - simple") { - withTable("t") { - val sqlScript = - """ - |BEGIN - |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; - |INSERT INTO t VALUES (1, 'a', 1.0); - |SELECT a, b FROM t WHERE a = 12; - |SELECT a FROM t; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // select with filter - Seq(Row(1)) - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("multi statement - count") { - withTable("t") { - val sqlScript = - """ - |BEGIN - |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; - |INSERT INTO t VALUES (1, 'a', 1.0); - |INSERT INTO t VALUES (1, 'a', 1.0); - |SELECT - | CASE WHEN COUNT(*) > 10 THEN true - | ELSE false - | END AS MoreThanTen - |FROM t; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert #1 - Seq.empty[Row], // insert #2 - Seq(Row(false)) - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("session vars - set and read") { - val sqlScript = - """ - |BEGIN - |DECLARE var = 1; - |SET VAR var = var + 1; - |SELECT var; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var - ) - verifySqlScriptResult(sqlScript, expected) - } - - test("session vars - set and read scoped") { - val sqlScript = - """ - |BEGIN - | BEGIN - | DECLARE var = 1; - | SELECT var; - | END; - | 
BEGIN - | DECLARE var = 2; - | SELECT var; - | END; - | BEGIN - | DECLARE var = 3; - | SET VAR var = var + 1; - | SELECT var; - | END; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row], // drop var - Seq.empty[Row], // declare var - Seq(Row(2)), // select - Seq.empty[Row], // drop var - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(4)), // select - Seq.empty[Row], // drop var - ) - verifySqlScriptResult(sqlScript, expected) - } -} \ No newline at end of file From 45294b06d78511388520e41f3117e2659dbd9935 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 2 Jul 2024 16:00:31 +0200 Subject: [PATCH 24/99] Remove unnecessary level in parser --- .../org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 | 8 ++------ .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 7 ++++--- .../sql/catalyst/parser/SqlScriptingParserSuite.scala | 6 +++--- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index d1872f909d2f6..433fdbccf3abe 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -68,16 +68,12 @@ singleStatement : statement SEMICOLON* EOF ; -label - : multipartIdentifier - ; - beginLabel - : label COLON + : multipartIdentifier COLON ; endLabel - : label + : multipartIdentifier ; singleExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ede41e3679012..0a2f841f8a03a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -144,15 +144,16 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { (beginLabelCtx, endLabelCtx) match { case (Some(bl: BeginLabelContext), Some(el: EndLabelContext)) - if bl.label().getText.nonEmpty && bl.label().getText != el.label().getText => - throw SparkException.internalError("Both labels should be same.") + if bl.multipartIdentifier().getText.nonEmpty && + bl.multipartIdentifier().getText != el.multipartIdentifier().getText => + throw SparkException.internalError("Both labels should be same.") case (None, Some(_)) => throw SparkException.internalError("End label can't exist without begin label.") case _ => } val labelText = beginLabelCtx. 
- map(_.label().getText).getOrElse(java.util.UUID.randomUUID.toString) + map(_.multipartIdentifier().getText).getOrElse(java.util.UUID.randomUUID.toString) visitCompoundBodyImpl(ctx.compoundBody(), labelText) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 375303782a114..7b485f58d2a9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -177,7 +177,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.label.equals("lbl")) } - test("compound: beginLabel + endlLabel") { + test("compound: beginLabel + endLabel") { val batch = """ |lbl: BEGIN @@ -193,7 +193,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.label.equals("lbl")) } - test("compound: beginLabel + endlLabel with different values") { + test("compound: beginLabel + endLabel with different values") { val batch = """ |lbl_begin: BEGIN @@ -210,7 +210,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(e.getMessage.contains("Both labels should be same.")) } - test("compound: endlLabel") { + test("compound: endLabel") { val batch = """ |BEGIN From e65123b748175132a5040705d0ce8b87fe764733 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 2 Jul 2024 16:01:55 +0200 Subject: [PATCH 25/99] Add default argument value for label to visitCompoundBodyImpl --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0a2f841f8a03a..7910b4de53349 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -130,7 +130,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundBody] } - private def visitCompoundBodyImpl(ctx: CompoundBodyContext, label: String): CompoundBody = { + private def visitCompoundBodyImpl(ctx: CompoundBodyContext, label: String = ""): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() ctx.compoundStatements.forEach(compoundStatement => { buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement] @@ -158,7 +158,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = { - visitCompoundBodyImpl(ctx, "") + visitCompoundBodyImpl(ctx) } override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement = From 435454db0ae8e38cba101736c8496a405e1f0858 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 8 Jul 2024 10:07:52 +0200 Subject: [PATCH 26/99] Add logical operators and lexer changes --- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 8 ++++++++ .../parser/SqlScriptingLogicalOperators.scala | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 
85a4633e80502..a6577a3d9596a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -172,8 +172,10 @@ COMPACTIONS: 'COMPACTIONS'; COMPENSATION: 'COMPENSATION'; COMPUTE: 'COMPUTE'; CONCATENATE: 'CONCATENATE'; +CONDITION: 'CONDITION'; CONSTRAINT: 'CONSTRAINT'; CONTAINS: 'CONTAINS'; +CONTINUE: 'CONTINUE'; COST: 'COST'; CREATE: 'CREATE'; CROSS: 'CROSS'; @@ -223,6 +225,7 @@ EXCEPT: 'EXCEPT'; EXCHANGE: 'EXCHANGE'; EXCLUDE: 'EXCLUDE'; EXISTS: 'EXISTS'; +EXIT: 'EXIT'; EXPLAIN: 'EXPLAIN'; EXPORT: 'EXPORT'; EXTENDED: 'EXTENDED'; @@ -249,6 +252,7 @@ GLOBAL: 'GLOBAL'; GRANT: 'GRANT'; GROUP: 'GROUP'; GROUPING: 'GROUPING'; +HANDLER: 'HANDLER'; HAVING: 'HAVING'; BINARY_HEX: 'X'; HOUR: 'HOUR'; @@ -569,6 +573,10 @@ IDENTIFIER | UNICODE_LETTER+ '://' (UNICODE_LETTER | DIGIT | '_' | '/' | '-' | '.' | '?' | '=' | '&' | '#' | '%')+ ; +SQLSTATE + : [0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z] + ; + BACKQUOTED_IDENTIFIER : '`' ( ~'`' | '``' )* '`' ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index adba3dc42115d..263786e22a02a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -55,3 +55,21 @@ case class SingleStatement(parsedPlan: LogicalPlan) case class CompoundBody( collection: Seq[CompoundPlanStatement], label: String) extends CompoundPlanStatement + +/** + * Logical operator for an error condition. + * @param sqlstate SQLSTATE. + * @param conditionName Name of the error condition. + */ +case class ErrorCondition( + sqlstate: String, + conditionName: String) extends CompoundPlanStatement + +/** + * Logical operator for an error condition. + * @param conditionName Name of the error condition variable for which the handler is built. + * @param body CompoundBody of the handler. + */ +case class ErrorHandler( + conditionName: String, + body: CompoundBody) extends CompoundPlanStatement From 5d95075513306e907825c6c1c220a3f92a27e627 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 8 Jul 2024 17:31:47 +0200 Subject: [PATCH 27/99] Add changes to visitor nodes and grammar --- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 4 ++- .../sql/catalyst/parser/SqlBaseParser.g4 | 18 +++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 31 +++++++++++++++++++ .../parser/SqlScriptingLogicalOperators.scala | 21 +++++++++---- .../scripting/SqlScriptingExecutionNode.scala | 3 ++ .../scripting/SqlScriptingInterpreter.scala | 2 ++ 6 files changed, 72 insertions(+), 7 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index a6577a3d9596a..3c4e2d1ba23ec 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -405,6 +405,8 @@ SORTED: 'SORTED'; SOURCE: 'SOURCE'; SPECIFIC: 'SPECIFIC'; SQL: 'SQL'; +SQLEXCEPTION: 'SQLEXCEPTION'; +SQLSTATE: 'SQLSTATE'; START: 'START'; STATISTICS: 'STATISTICS'; STORED: 'STORED'; @@ -573,7 +575,7 @@ IDENTIFIER | UNICODE_LETTER+ '://' (UNICODE_LETTER | DIGIT | '_' | '/' | '-' | '.' | '?' 
| '=' | '&' | '#' | '%')+ ; -SQLSTATE +SQLSTATE_VALUE : [0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z] ; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 433fdbccf3abe..055d0d8487a5e 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -62,12 +62,30 @@ compoundBody compoundStatement : statement | beginEndCompoundBlock + | declareCondition + | declareHandler ; singleStatement : statement SEMICOLON* EOF ; +conditionValue + : SQLSTATE SQLSTATE_VALUE + ; + +conditionValueList + : ((conditionValues+=conditionValue (COMMA conditionValues+=conditionValue)*) | SQLEXCEPTION) + ; + +declareCondition + : DECLARE multipartIdentifier CONDITION FOR conditionValue + ; + +declareHandler + : DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN compoundBody END | statement) + ; + beginLabel : multipartIdentifier COLON ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7910b4de53349..d10980f388b57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -170,6 +170,37 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } } + override def visitConditionValue(ctx: ConditionValueContext): String = { + ctx.SQLSTATE_VALUE().getText + } + + override def visitConditionValueList(ctx: ConditionValueListContext): Seq[String] = { + if (ctx.SQLEXCEPTION() != null) { + return Seq("SQLEXCEPTION") + } + + val buff = ListBuffer[String]() + ctx.conditionValues.forEach(conditionValue => { + buff += visit(conditionValue).asInstanceOf[String] + }) + buff.toSeq + } + + override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { + val conditionName = ctx.multipartIdentifier().getText + val conditionValue = visit(ctx.conditionValue()).asInstanceOf[String] + + ErrorCondition(conditionName, conditionValue) + } + + override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = { + val conditions = visit(ctx.conditionValueList()).asInstanceOf[Seq[String]] + val body = visit(ctx.compoundBody()).asInstanceOf[CompoundBody] + val handlerType = if (ctx.EXIT() != null) HandlerType.EXIT else HandlerType.CONTINUE + + ErrorHandler(conditions, body, handlerType) + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { visit(ctx.statement).asInstanceOf[LogicalPlan] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 263786e22a02a..d06e9bccd29eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} @@ -58,18 +59,26 @@ case 
class CompoundBody( /** * Logical operator for an error condition. - * @param sqlstate SQLSTATE. * @param conditionName Name of the error condition. + * @param value SQLSTATE or Error Code. */ case class ErrorCondition( - sqlstate: String, - conditionName: String) extends CompoundPlanStatement + conditionName: String, + value: String) extends CompoundPlanStatement + +object HandlerType extends Enumeration { + type HandlerType = Value + val EXIT, CONTINUE = Value +} /** * Logical operator for an error condition. - * @param conditionName Name of the error condition variable for which the handler is built. + * @param conditions Name of the error condition variable for which the handler is built. * @param body CompoundBody of the handler. + * @param handlerType Type of the handler (CONTINUE or EXIT). */ case class ErrorHandler( - conditionName: String, - body: CompoundBody) extends CompoundPlanStatement + conditions: Seq[String], + body: CompoundBody, + handlerType: HandlerType) extends CompoundPlanStatement + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 7250de5726482..d5755b7fcbafd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -145,3 +145,6 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState */ class CompoundBodyExec(statements: Seq[CompoundStatementExec], label: String = "") extends CompoundNestedStatementIteratorExec(statements) + +class HandlerExec(statements: Seq[CompoundStatementExec]) + extends CompoundNestedStatementIteratorExec(statements) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 755a46428d554..19a824d8a064b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -79,5 +79,7 @@ case class SqlScriptingInterpreter() { sparkStatement.parsedPlan, sparkStatement.origin, isInternal = false) + case _ => + throw new UnsupportedOperationException(s"Unsupported statement type: $node") } } From 83f6651d4eeab09da70c3eb083712a1525bbb92b Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 10 Jul 2024 14:37:04 +0200 Subject: [PATCH 28/99] Fix grammar and add AST visitors --- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 6 ++--- .../sql/catalyst/parser/SqlBaseParser.g4 | 8 ++++--- .../sql/catalyst/parser/AstBuilder.scala | 22 +++++++++---------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 3c4e2d1ba23ec..ab3a859273240 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -243,6 +243,7 @@ FOR: 'FOR'; FOREIGN: 'FOREIGN'; FORMAT: 'FORMAT'; FORMATTED: 'FORMATTED'; +FOUND: 'FOUND'; FROM: 'FROM'; FULL: 'FULL'; FUNCTION: 'FUNCTION'; @@ -406,7 +407,6 @@ SOURCE: 'SOURCE'; SPECIFIC: 'SPECIFIC'; SQL: 'SQL'; SQLEXCEPTION: 'SQLEXCEPTION'; -SQLSTATE: 'SQLSTATE'; START: 'START'; STATISTICS: 'STATISTICS'; 
STORED: 'STORED'; @@ -575,8 +575,8 @@ IDENTIFIER | UNICODE_LETTER+ '://' (UNICODE_LETTER | DIGIT | '_' | '/' | '-' | '.' | '?' | '=' | '&' | '#' | '%')+ ; -SQLSTATE_VALUE - : [0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z] +SQLSTATE + : '\'' [0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z] '\'' ; BACKQUOTED_IDENTIFIER diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 055d0d8487a5e..db2a1717ccf9d 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -71,15 +71,17 @@ singleStatement ; conditionValue - : SQLSTATE SQLSTATE_VALUE + : SQLSTATE + | multipartIdentifier ; + conditionValueList - : ((conditionValues+=conditionValue (COMMA conditionValues+=conditionValue)*) | SQLEXCEPTION) + : ((conditionValues+=conditionValue (COMMA conditionValues+=conditionValue)*) | SQLEXCEPTION | NOT FOUND) ; declareCondition - : DECLARE multipartIdentifier CONDITION FOR conditionValue + : DECLARE multipartIdentifier CONDITION (FOR SQLSTATE)? ; declareHandler diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d10980f388b57..7d01094c90e89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -171,24 +171,24 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitConditionValue(ctx: ConditionValueContext): String = { - ctx.SQLSTATE_VALUE().getText + Option(ctx.multipartIdentifier()).map(_.getText).getOrElse(ctx.SQLSTATE().getText) } override def visitConditionValueList(ctx: ConditionValueListContext): Seq[String] = { - if (ctx.SQLEXCEPTION() != null) { - return Seq("SQLEXCEPTION") + Option(ctx.SQLEXCEPTION()).map(_ => Seq("SQLEXCEPTION")).getOrElse { + Option(ctx.NOT()).map(_ => Seq("NOT FOUND")).getOrElse { + val buff = ListBuffer[String]() + ctx.conditionValues.forEach { conditionValue => + buff += visit(conditionValue).asInstanceOf[String] + } + buff.toSeq + } } - - val buff = ListBuffer[String]() - ctx.conditionValues.forEach(conditionValue => { - buff += visit(conditionValue).asInstanceOf[String] - }) - buff.toSeq } override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { val conditionName = ctx.multipartIdentifier().getText - val conditionValue = visit(ctx.conditionValue()).asInstanceOf[String] + val conditionValue = Option(ctx.SQLSTATE()).map(_.getText).getOrElse("45000") ErrorCondition(conditionName, conditionValue) } @@ -196,7 +196,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = { val conditions = visit(ctx.conditionValueList()).asInstanceOf[Seq[String]] val body = visit(ctx.compoundBody()).asInstanceOf[CompoundBody] - val handlerType = if (ctx.EXIT() != null) HandlerType.EXIT else HandlerType.CONTINUE + val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) ErrorHandler(conditions, body, handlerType) } From f497be9458761f24919dba0e1a9bf4c625de2fc7 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 10 Jul 2024 17:09:01 +0200 Subject: [PATCH 29/99] Add grammar changes for handlers --- 
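Note: the revised grammar is easiest to read through the scripts it accepts. A minimal sketch, assuming the parseScript helper used in SqlScriptingParserSuite below; the condition name and the '12345' SQLSTATE value are illustrative only and not taken from the patch.

    // Rough sketch of a script the declareCondition/declareHandler rules accept after this
    // change: an explicit quoted SQLSTATE, a single-statement handler, and a compound-body handler.
    val sqlScriptText =
      """
        |BEGIN
        |  DECLARE my_cond CONDITION FOR '12345';
        |  DECLARE CONTINUE HANDLER FOR my_cond SELECT 1;
        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
        |  BEGIN
        |    SELECT 2;
        |  END;
        |  SELECT 3;
        |END""".stripMargin
    val tree = parseScript(sqlScriptText)  // CompoundBody holding the parsed statements
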
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 2 +- .../sql/catalyst/parser/SqlBaseParser.g4 | 1 - .../sql/catalyst/parser/AstBuilder.scala | 12 ++++- .../parser/SqlScriptingParserSuite.scala | 50 ++++++++++++++++--- 4 files changed, 53 insertions(+), 12 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index ab3a859273240..a81b1e7b733c4 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -576,7 +576,7 @@ IDENTIFIER ; SQLSTATE - : '\'' [0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z][0-9A-Z] '\'' + : '\'' (LETTER | DIGIT)(LETTER | DIGIT)(LETTER | DIGIT)(LETTER | DIGIT)(LETTER | DIGIT) '\'' ; BACKQUOTED_IDENTIFIER diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index db2a1717ccf9d..5513a07af93a8 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -75,7 +75,6 @@ conditionValue | multipartIdentifier ; - conditionValueList : ((conditionValues+=conditionValue (COMMA conditionValues+=conditionValue)*) | SQLEXCEPTION | NOT FOUND) ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7d01094c90e89..e671ee56a86bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -166,7 +166,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { Option(ctx.statement()).map {s => SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan]) }.getOrElse { - visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundPlanStatement] + val stmt = Option(ctx.beginEndCompoundBlock()). 
+ getOrElse(Option(ctx.declareHandler()).getOrElse(ctx.declareCondition())) + visit(stmt).asInstanceOf[CompoundPlanStatement] } } @@ -195,7 +197,13 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = { val conditions = visit(ctx.conditionValueList()).asInstanceOf[Seq[String]] - val body = visit(ctx.compoundBody()).asInstanceOf[CompoundBody] + + val body = Option(ctx.compoundBody()).map(visit).getOrElse { + val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan] + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), + java.util.UUID.randomUUID.toString) + }.asInstanceOf[CompoundBody] + val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) ErrorHandler(conditions, body, handlerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 7b485f58d2a9a..8e4f902d55a26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -162,7 +162,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } test("compound: beginLabel") { - val batch = + val sqlScriptText = """ |lbl: BEGIN | SELECT 1; @@ -171,14 +171,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END""".stripMargin - val tree = parseScript(batch) + val tree = parseScript(sqlScriptText) assert(tree.collection.length == 5) assert(tree.collection.forall(_.isInstanceOf[SingleStatement])) assert(tree.label.equals("lbl")) } test("compound: beginLabel + endLabel") { - val batch = + val sqlScriptText = """ |lbl: BEGIN | SELECT 1; @@ -187,14 +187,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END lbl""".stripMargin - val tree = parseScript(batch) + val tree = parseScript(sqlScriptText) assert(tree.collection.length == 5) assert(tree.collection.forall(_.isInstanceOf[SingleStatement])) assert(tree.label.equals("lbl")) } test("compound: beginLabel + endLabel with different values") { - val batch = + val sqlScriptText = """ |lbl_begin: BEGIN | SELECT 1; @@ -204,14 +204,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT * FROM T; |END lbl_end""".stripMargin val e = intercept[SparkException] { - parseScript(batch) + parseScript(sqlScriptText) } assert(e.getErrorClass === "INTERNAL_ERROR") assert(e.getMessage.contains("Both labels should be same.")) } test("compound: endLabel") { - val batch = + val sqlScriptText = """ |BEGIN | SELECT 1; @@ -221,12 +221,46 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT * FROM T; |END lbl""".stripMargin val e = intercept[SparkException] { - parseScript(batch) + parseScript(sqlScriptText) } assert(e.getErrorClass === "INTERNAL_ERROR") assert(e.getMessage.contains("End label can't exist without begin label.")) } + test("declare condition: default sqlstate") { + val sqlScriptText = + """ + |BEGIN + | DECLARE test CONDITION; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ErrorCondition]) + 
assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("45000")) + } + + test("declare handler") { + val sqlScriptText = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR test BEGIN SELECT 1; END; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ErrorHandler]) + } + + test("declare handler single statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR test SELECT 1; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ErrorHandler]) + } + // Helper methods def cleanupStatementString(statementStr: String): String = { statementStr From f0c202649036f2f47af7510feee4cd54e5101f78 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 10:25:05 +0200 Subject: [PATCH 30/99] Add visit for conditions and handlers --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 6 +++++- .../sql/catalyst/parser/SqlScriptingLogicalOperators.scala | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e671ee56a86bb..5c63024d70a00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -132,8 +132,12 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { private def visitCompoundBodyImpl(ctx: CompoundBodyContext, label: String = ""): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() + val handlers = ListBuffer[ErrorHandler]() ctx.compoundStatements.forEach(compoundStatement => { - buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement] + Option(compoundStatement.declareCondition()).map(visit).foreach(handlers += _) + Option(compoundStatement.declareCondition()).map(visit).foreach(buff += _) + Option(compoundStatement.statement()).map(visit).foreach(buff += _) + Option(compoundStatement.beginEndCompoundBlock()).map(visit).foreach(buff += _) }) CompoundBody(buff.toSeq, label) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index d06e9bccd29eb..05b4aeb4f0c9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -55,7 +55,8 @@ case class SingleStatement(parsedPlan: LogicalPlan) */ case class CompoundBody( collection: Seq[CompoundPlanStatement], - label: String) extends CompoundPlanStatement + label: String, + handlers: Seq[ErrorHandler]) extends CompoundPlanStatement /** * Logical operator for an error condition. 
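Note: the point of the new handlers field is that handler declarations no longer travel in the ordinary statement list, so the interpreter can look them up directly. A rough parser-level sketch of that end state, assuming parseScript and the CompoundBody/HandlerType definitions from these patches; the statement routing it relies on is only completed in the next two commits.

    // Sketch: the handler is expected to land in tree.handlers, leaving only
    // SELECT 2 in tree.collection.
    val tree = parseScript(
      """
        |BEGIN
        |  DECLARE CONTINUE HANDLER FOR test SELECT 1;
        |  SELECT 2;
        |END""".stripMargin)
    assert(tree.collection.length == 1)
    assert(tree.handlers.length == 1)
    assert(tree.handlers.head.handlerType == HandlerType.CONTINUE)
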
From 84889047df13dc7eebe7397afaf6a7feaa893a49 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 14:31:31 +0200 Subject: [PATCH 31/99] Change grammar --- .../sql/catalyst/parser/SqlBaseParser.g4 | 4 +-- .../sql/catalyst/parser/AstBuilder.scala | 25 +++++++++++-------- .../parser/SqlScriptingParserSuite.scala | 9 ++++--- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 5513a07af93a8..9dae402ebb1f6 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -71,7 +71,7 @@ singleStatement ; conditionValue - : SQLSTATE + : stringLit | multipartIdentifier ; @@ -80,7 +80,7 @@ conditionValueList ; declareCondition - : DECLARE multipartIdentifier CONDITION (FOR SQLSTATE)? + : DECLARE multipartIdentifier CONDITION (FOR stringLit)? ; declareHandler diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5c63024d70a00..207d3f26223d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -122,7 +122,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { }.getOrElse { val logicalPlan = visitSingleStatement(ctx.singleStatement()) CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), - java.util.UUID.randomUUID.toString) + java.util.UUID.randomUUID.toString, Seq()) } } @@ -134,12 +134,16 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val buff = ListBuffer[CompoundPlanStatement]() val handlers = ListBuffer[ErrorHandler]() ctx.compoundStatements.forEach(compoundStatement => { - Option(compoundStatement.declareCondition()).map(visit).foreach(handlers += _) - Option(compoundStatement.declareCondition()).map(visit).foreach(buff += _) - Option(compoundStatement.statement()).map(visit).foreach(buff += _) - Option(compoundStatement.beginEndCompoundBlock()).map(visit).foreach(buff += _) + Option(compoundStatement.declareHandler()).map(visit). + foreach(handlers += _.asInstanceOf[ErrorHandler]) + Option(compoundStatement.declareCondition()).map(visit). + foreach(buff += _.asInstanceOf[ErrorCondition]) + Option(compoundStatement.statement()).map(visit). + foreach(buff += _.asInstanceOf[CompoundPlanStatement]) + Option(compoundStatement.beginEndCompoundBlock()).map(visit). 
+ foreach(buff += _.asInstanceOf[CompoundPlanStatement]) }) - CompoundBody(buff.toSeq, label) + CompoundBody(buff.toSeq, label, handlers.toSeq) } override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = { @@ -177,7 +181,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitConditionValue(ctx: ConditionValueContext): String = { - Option(ctx.multipartIdentifier()).map(_.getText).getOrElse(ctx.SQLSTATE().getText) + Option(ctx.multipartIdentifier()).map(_.getText).getOrElse(ctx.stringLit().getText) } override def visitConditionValueList(ctx: ConditionValueListContext): Seq[String] = { @@ -194,9 +198,10 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { val conditionName = ctx.multipartIdentifier().getText - val conditionValue = Option(ctx.SQLSTATE()).map(_.getText).getOrElse("45000") +// val conditionValue = Option(ctx.stringLit()).map(_.getText).getOrElse("45000") - ErrorCondition(conditionName, conditionValue) +// ErrorCondition(conditionName, conditionValue.asInstanceOf[String]) + ErrorCondition(conditionName, "20000") } override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = { @@ -205,7 +210,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val body = Option(ctx.compoundBody()).map(visit).getOrElse { val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan] CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), - java.util.UUID.randomUUID.toString) + java.util.UUID.randomUUID.toString, Seq()) }.asInstanceOf[CompoundBody] val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 8e4f902d55a26..5990ed9d8e51c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -236,17 +236,18 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val tree = parseScript(sqlScriptText) assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ErrorCondition]) - assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("45000")) +// assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("45000")) } test("declare handler") { val sqlScriptText = """ |BEGIN + | SELECT 1; | DECLARE CONTINUE HANDLER FOR test BEGIN SELECT 1; END; |END""".stripMargin val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) + assert(tree.handlers.length == 1) assert(tree.collection.head.isInstanceOf[ErrorHandler]) } @@ -257,8 +258,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | DECLARE CONTINUE HANDLER FOR test SELECT 1; |END""".stripMargin val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[ErrorHandler]) + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ErrorHandler]) } // Helper methods From 1e453730892ad4d7d0799e9999d3890c92133755 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 15:20:45 +0200 Subject: [PATCH 32/99] Fix AstBuilder --- 
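Note: ctx.stringLit().getText still carries the surrounding single quotes (for example '12000'), which is why visitDeclareCondition below defaults to '45000' and then strips the quotes before storing the SQLSTATE. A standalone sketch of that defaulting and stripping, with names local to the sketch rather than taken from the patch; the following commit additionally checks the result against the five-character SQLSTATE shape.

    // rawLiteral stands in for Option(ctx.stringLit()).map(_.getText).
    def conditionSqlState(rawLiteral: Option[String]): String =
      rawLiteral.getOrElse("'45000'").replace("'", "")

    assert(conditionSqlState(None) == "45000")              // DECLARE test CONDITION;
    assert(conditionSqlState(Some("'12000'")) == "12000")   // DECLARE test CONDITION FOR '12000';
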
.../sql/catalyst/parser/AstBuilder.scala | 26 ++++++++----------- .../parser/SqlScriptingParserSuite.scala | 17 +++++++++--- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2b76876c8538..9b807bab1cce6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -122,8 +122,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { visit(s).asInstanceOf[CompoundBody] }.getOrElse { val logicalPlan = visitSingleStatement(ctx.singleStatement()) - CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), - Some(java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)), Seq()) + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), None, Seq()) } } @@ -137,14 +136,12 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val buff = ListBuffer[CompoundPlanStatement]() val handlers = ListBuffer[ErrorHandler]() ctx.compoundStatements.forEach(compoundStatement => { - Option(compoundStatement.declareHandler()).map(visit). - foreach(handlers += _.asInstanceOf[ErrorHandler]) - Option(compoundStatement.declareCondition()).map(visit). - foreach(buff += _.asInstanceOf[ErrorCondition]) - Option(compoundStatement.statement()).map(visit). - foreach(buff += _.asInstanceOf[CompoundPlanStatement]) - Option(compoundStatement.beginEndCompoundBlock()).map(visit). - foreach(buff += _.asInstanceOf[CompoundPlanStatement]) + val stmt = visit(compoundStatement).asInstanceOf[CompoundPlanStatement] + + stmt match { + case handler: ErrorHandler => handlers += handler + case s => buff += s + } }) CompoundBody(buff.toSeq, label, handlers.toSeq) } @@ -203,10 +200,10 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { val conditionName = ctx.multipartIdentifier().getText -// val conditionValue = Option(ctx.stringLit()).map(_.getText).getOrElse("45000") + val conditionValue = Option(ctx.stringLit()).map(_.getText).getOrElse("'45000'"). 
+ replace("'", "") -// ErrorCondition(conditionName, conditionValue.asInstanceOf[String]) - ErrorCondition(conditionName, "20000") + ErrorCondition(conditionName, conditionValue) } override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = { @@ -214,8 +211,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val body = Option(ctx.compoundBody()).map(visit).getOrElse { val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan] - CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), - java.util.UUID.randomUUID.toString, Seq()) + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), None, Seq()) }.asInstanceOf[CompoundBody] val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 7e9a0adf0e603..97c884c7e5c42 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -268,19 +268,30 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val tree = parseScript(sqlScriptText) assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ErrorCondition]) -// assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("45000")) + assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("45000")) + } + + test("declare condition: custom sqlstate") { + val sqlScriptText = + """ + |BEGIN + | DECLARE test CONDITION FOR '12000'; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ErrorCondition]) + assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("12000")) } test("declare handler") { val sqlScriptText = """ |BEGIN - | SELECT 1; | DECLARE CONTINUE HANDLER FOR test BEGIN SELECT 1; END; |END""".stripMargin val tree = parseScript(sqlScriptText) assert(tree.handlers.length == 1) - assert(tree.collection.head.isInstanceOf[ErrorHandler]) + assert(tree.handlers.head.isInstanceOf[ErrorHandler]) } test("declare handler single statement") { From 75bfb62330ab61eb948cf1a731b3c8d46d3486f3 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 15:32:42 +0200 Subject: [PATCH 33/99] Add check for sqlstate format --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9b807bab1cce6..841bde2be072f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -203,6 +203,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val conditionValue = Option(ctx.stringLit()).map(_.getText).getOrElse("'45000'"). 
replace("'", "") + val sqlStateRegex = "^[A-Za-z0-9]{5}$".r + assert(sqlStateRegex.findFirstIn(conditionValue).isDefined) + ErrorCondition(conditionName, conditionValue) } From 428570b90ff9382dd1671dc933472c12e3b32ed9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 15:50:34 +0200 Subject: [PATCH 34/99] Add interpreterBuilder to SparkSessionExtensions --- .../apache/spark/sql/SparkSessionExtensions.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 677dba0082575..fbb37fe57e241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import scala.collection.mutable - import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} +import org.apache.spark.sql.scripting.SqlScriptingInterpreter /** * :: Experimental :: @@ -110,6 +110,7 @@ class SparkSessionExtensions { type CheckRuleBuilder = SparkSession => LogicalPlan => Unit type StrategyBuilder = SparkSession => Strategy type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + type InterpreterBuilder = (SparkSession, SqlScriptingInterpreter) => SqlScriptingInterpreter type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder) type ColumnarRuleBuilder = SparkSession => ColumnarRule @@ -330,6 +331,16 @@ class SparkSessionExtensions { } } + private[this] val interpreterBuilders = mutable.Buffer.empty[InterpreterBuilder] + + private[sql] def buildInterpreter( + session: SparkSession, + initial: SqlScriptingInterpreter): SqlScriptingInterpreter = { + interpreterBuilders.foldLeft(initial) { (interpreter, builder) => + builder(session, interpreter) + } + } + /** * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session * and an initial parser. 
The latter allows for a user to create a partial parser and to delegate From 0571f3a8cd6c65691f9b250bfb55b4ab68c67aa2 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 17:08:42 +0200 Subject: [PATCH 35/99] Include SparkSession into interpreter --- .../scripting/SqlScriptingExecutionNode.scala | 40 +++++++++++++++++-- .../scripting/SqlScriptingInterpreter.scala | 5 ++- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 51e9304297f4d..119cba581da61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -80,6 +81,16 @@ class SingleStatementExec( */ var isExecuted = false + /** + * Whether an error was raised during the execution of this statement. + */ + var raisedError = false + + /** + * Whether the statement result should be collected in the final result. + */ + var collectResult = false + /** * Get the SQL query text corresponding to this statement. * @return @@ -91,6 +102,24 @@ class SingleStatementExec( } override def reset(): Unit = isExecuted = false + + def execute(session: SparkSession): Option[Seq[Row]] = { + if (!isExecuted) { + if (collectResult) { + try { + return Some(Dataset.ofRows(session, parsedPlan).collect()) + } catch { + case e: Exception => + raisedError = true + // TODO: check handlers for error conditions + logError(s"Error executing statement: ${getText}", e) + return None + } + } + Dataset.ofRows(session, parsedPlan).collect() + } + None + } } /** @@ -99,7 +128,9 @@ class SingleStatementExec( * @param collection * Collection of child execution nodes. */ -abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) +abstract class CompoundNestedStatementIteratorExec( + collection: Seq[CompoundStatementExec], + session: SparkSession) extends NonLeafStatementExec { private var localIterator = collection.iterator @@ -123,8 +154,9 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState curr match { case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") - case Some(statement: LeafStatementExec) => + case Some(statement: SingleStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None + statement.execute(session) statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { @@ -153,5 +185,5 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * @param statements * Executable nodes for nested statements within the CompoundBody. 
*/ -class CompoundBodyExec(statements: Seq[CompoundStatementExec]) - extends CompoundNestedStatementIteratorExec(statements) +class CompoundBodyExec(statements: Seq[CompoundStatementExec], session: SparkSession) + extends CompoundNestedStatementIteratorExec(statements, session) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 6453b204e1623..40eff0d657c9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.scripting +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} @@ -25,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.Origin /** * SQL scripting interpreter - builds SQL script execution plan. */ -case class SqlScriptingInterpreter() { +case class SqlScriptingInterpreter(session: SparkSession) { /** * Build execution plan and return statements that need to be executed, @@ -73,7 +74,7 @@ case class SqlScriptingInterpreter() { .map(new SingleStatementExec(_, Origin(), isInternal = true)) .reverse new CompoundBodyExec( - body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables) + body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, session) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, From f65ecd7d6f83b2067a7048456091f99e81ed0007 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 11 Jul 2024 17:52:02 +0200 Subject: [PATCH 36/99] Move some things to concrete class --- .../scripting/SqlScriptingExecutionNode.scala | 72 +++++++++++-------- .../scripting/SqlScriptingInterpreter.scala | 19 ++++- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 119cba581da61..37d99e9e91830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.network.shuffle.ErrorHandler import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -102,24 +102,6 @@ class SingleStatementExec( } override def reset(): Unit = isExecuted = false - - def execute(session: SparkSession): Option[Seq[Row]] = { - if (!isExecuted) { - if (collectResult) { - try { - return Some(Dataset.ofRows(session, parsedPlan).collect()) - } catch { - case e: Exception => - raisedError = true - // TODO: check handlers for error conditions - logError(s"Error executing statement: ${getText}", e) - return None - } - } - Dataset.ofRows(session, parsedPlan).collect() - } - None - } } /** @@ -128,15 +110,14 @@ class SingleStatementExec( * @param collection * Collection of 
child execution nodes. */ -abstract class CompoundNestedStatementIteratorExec( - collection: Seq[CompoundStatementExec], - session: SparkSession) +abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) extends NonLeafStatementExec { - private var localIterator = collection.iterator - private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None + protected var localIterator: Iterator[CompoundStatementExec] = collection.iterator + protected var curr: Option[CompoundStatementExec] = + if (localIterator.hasNext) Some(localIterator.next()) else None - private lazy val treeIterator: Iterator[CompoundStatementExec] = + protected lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { val childHasNext = curr match { @@ -156,7 +137,6 @@ abstract class CompoundNestedStatementIteratorExec( "No more elements to iterate through in the current SQL compound statement.") case Some(statement: SingleStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None - statement.execute(session) statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { @@ -185,5 +165,41 @@ abstract class CompoundNestedStatementIteratorExec( * @param statements * Executable nodes for nested statements within the CompoundBody. */ -class CompoundBodyExec(statements: Seq[CompoundStatementExec], session: SparkSession) - extends CompoundNestedStatementIteratorExec(statements, session) +class CompoundBodyExec(statements: Seq[CompoundStatementExec], handlers: Seq[ErrorHandler]) + extends CompoundNestedStatementIteratorExec(statements) { + + override protected lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = { + val childHasNext = curr match { + case Some(body: NonLeafStatementExec) => body.getTreeIterator.hasNext + case Some(_: LeafStatementExec) => true + case None => false + case _ => throw SparkException.internalError( + "Unknown statement type encountered during SQL script interpretation.") + } + localIterator.hasNext || childHasNext + } + + @scala.annotation.tailrec + override def next(): CompoundStatementExec = { + curr match { + case None => throw SparkException.internalError( + "No more elements to iterate through in the current SQL compound statement.") + case Some(statement: SingleStatementExec) => + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + statement + case Some(body: NonLeafStatementExec) => + if (body.getTreeIterator.hasNext) { + body.getTreeIterator.next() + } else { + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + next() + } + case _ => throw SparkException.internalError( + "Unknown statement type encountered during SQL script interpretation.") + } + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 40eff0d657c9e..9e6f3ec4a0dec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.scripting -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import 
org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} @@ -81,4 +81,21 @@ case class SqlScriptingInterpreter(session: SparkSession) { sparkStatement.origin, isInternal = false) } + + def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { + executionPlan.flatMap { + case statement: SingleStatementExec if !statement.isExecuted => + try { + statement.isExecuted = true + val result = Some(Dataset.ofRows(session, statement.parsedPlan).collect()) + if (statement.collectResult) result else None + } catch { + case e: Exception => + // TODO: check handlers for error conditions + statement.raisedError = true + None + } + case _ => None + } + } } From 46e41e8c12fb793f8f3e4b81709fe2477828690c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Jul 2024 12:02:54 +0200 Subject: [PATCH 37/99] Add handler execution node --- .../scripting/SqlScriptingExecutionNode.scala | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 37d99e9e91830..3257bca143dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.ErrorHandler +import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -91,6 +93,11 @@ class SingleStatementExec( */ var collectResult = false + /** + * Data returned after execution. + */ + var data: Option[Array[Row]] = None + /** * Get the SQL query text corresponding to this statement. * @return @@ -102,6 +109,18 @@ class SingleStatementExec( } override def reset(): Unit = isExecuted = false + + def execute(session: SparkSession): Unit = { + try { + isExecuted = true + val result = Some(Dataset.ofRows(session, parsedPlan).collect()) + if (collectResult) data = result + } catch { + case e: Exception => + // TODO: check handlers for error conditions + raisedError = true + } + } } /** @@ -165,7 +184,10 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * @param statements * Executable nodes for nested statements within the CompoundBody. 
*/ -class CompoundBodyExec(statements: Seq[CompoundStatementExec], handlers: Seq[ErrorHandler]) +class CompoundBodyExec( + statements: Seq[CompoundStatementExec], + handlers: Seq[ErrorHandler], + session: SparkSession) extends CompoundNestedStatementIteratorExec(statements) { override protected lazy val treeIterator: Iterator[CompoundStatementExec] = @@ -188,6 +210,10 @@ class CompoundBodyExec(statements: Seq[CompoundStatementExec], handlers: Seq[Err "No more elements to iterate through in the current SQL compound statement.") case Some(statement: SingleStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None + statement.execute(session) + if (statement.raisedError) { + + } statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { @@ -203,3 +229,17 @@ class CompoundBodyExec(statements: Seq[CompoundStatementExec], handlers: Seq[Err } } + +class ErrorHandlerExec( + conditions: Seq[String], + body: CompoundBodyExec, + handlerType: HandlerType) extends CompoundStatementExec { + + def execute(): Unit = { + val iterator = body.getTreeIterator + + while (iterator.hasNext) iterator.next() + } + + override def reset(): Unit = body.reset() +} From b9d74b98b3ef06952755676ef55b10ddd352a3a9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Jul 2024 16:06:39 +0200 Subject: [PATCH 38/99] Add error checking --- .../sql/catalyst/parser/AstBuilder.scala | 18 +++++++---- .../parser/SqlScriptingLogicalOperators.scala | 5 +++- .../scripting/SqlScriptingExecutionNode.scala | 27 +++++++++++++++-- .../scripting/SqlScriptingInterpreter.scala | 30 +++++++++++++++++-- 4 files changed, 69 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 841bde2be072f..e655ac17636b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -19,15 +19,12 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit - import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} - import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} - import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION @@ -47,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, Inte import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition -import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, 
DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt import org.apache.spark.sql.internal.SQLConf @@ -58,6 +55,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.random.RandomSampler +import scala.collection.mutable + /** * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. @@ -135,15 +134,24 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { label: Option[String]): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() val handlers = ListBuffer[ErrorHandler]() + val conditions = mutable.HashMap[String, String]() + val sqlstates = mutable.Set[String]() + ctx.compoundStatements.forEach(compoundStatement => { val stmt = visit(compoundStatement).asInstanceOf[CompoundPlanStatement] stmt match { case handler: ErrorHandler => handlers += handler + case condition: ErrorCondition => + assert(!conditions.contains(condition.conditionName)) // Check for duplicate names. + assert(!sqlstates.contains(condition.value)) // Check for duplicate sqlstates. + conditions += condition.conditionName -> condition.value + sqlstates += condition.value case s => buff += s } }) - CompoundBody(buff.toSeq, label, handlers.toSeq) + + CompoundBody(buff.toSeq, label, handlers.toSeq, conditions) } override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 6b94c0cc6f118..9171e61fb6598 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import scala.collection.mutable + /** * Trait for all SQL Scripting logical operators that are product of parsing phase. * These operators will be used by the SQL Scripting interpreter to generate execution nodes. @@ -59,7 +61,8 @@ case class SingleStatement(parsedPlan: LogicalPlan) case class CompoundBody( collection: Seq[CompoundPlanStatement], label: Option[String], - handlers: Seq[ErrorHandler]) extends CompoundPlanStatement + handlers: Seq[ErrorHandler], + conditions: mutable.HashMap[String, String]) extends CompoundPlanStatement /** * Logical operator for an error condition. 
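The parser-side changes above collect every declared condition into a name-to-SQLSTATE map that travels with the CompoundBody, alongside the handler list. A handler declared for a condition name is later resolved through that map, falling back to the raw string so a handler can also target a SQLSTATE directly (see the interpreter change further below). A minimal sketch, with an assumed condition name and SQLSTATE value used purely for illustration:

    // Assumed illustration only: condition MY_COND declared with SQLSTATE '22012'.
    conditions += "MY_COND" -> "22012"
    sqlstates += "22012"
    // Handler registration later resolves the name to its SQLSTATE, or keeps the
    // raw string when the handler was declared for a SQLSTATE directly:
    val key = conditions.getOrElse("MY_COND", "MY_COND")   // "22012"
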
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 3257bca143dd8..c9224da69393c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.scripting +import scala.collection.mutable + import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.network.shuffle.ErrorHandler -import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.parser.HandlerType +import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -98,6 +100,11 @@ class SingleStatementExec( */ var data: Option[Array[Row]] = None + /** + * Error state of the statement. + */ + var errorState: Option[String] = None + /** * Get the SQL query text corresponding to this statement. * @return @@ -186,10 +193,15 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState */ class CompoundBodyExec( statements: Seq[CompoundStatementExec], - handlers: Seq[ErrorHandler], + handlers: Seq[ErrorHandlerExec], + conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec], session: SparkSession) extends CompoundNestedStatementIteratorExec(statements) { + private def getHandler(condition: String): Option[ErrorHandlerExec] = { + conditionHandlerMap.get(condition) + } + override protected lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { @@ -212,7 +224,14 @@ class CompoundBodyExec( curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement.execute(session) if (statement.raisedError) { + val handler = getHandler(statement.errorState.get).get + handler.execute() + handler.reset() + if (handler.getHandlerType == HandlerType.EXIT) { + // TODO: premature exit from the compound ... 
+ curr = None + } } statement case Some(body: NonLeafStatementExec) => @@ -235,6 +254,8 @@ class ErrorHandlerExec( body: CompoundBodyExec, handlerType: HandlerType) extends CompoundStatementExec { + def getHandlerType: HandlerType = handlerType + def execute(): Unit = { val iterator = body.getTreeIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 9e6f3ec4a0dec..d6468ad9e337c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,10 +19,13 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorHandler, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + /** * SQL scripting interpreter - builds SQL script execution plan. */ @@ -73,13 +76,36 @@ case class SqlScriptingInterpreter(session: SparkSession) { .map(varName => DropVariable(varName, ifExists = true)) .map(new SingleStatementExec(_, Origin(), isInternal = true)) .reverse + + val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]() + val handlers = ListBuffer[ErrorHandlerExec]() + body.handlers.foreach(handler => { + val handlerBodyExec = transformTreeIntoExecutable(handler.body). + asInstanceOf[CompoundBodyExec] + val handlerExec = + new ErrorHandlerExec(handler.conditions, handlerBodyExec, handler.handlerType) + + handler.conditions.foreach(condition => { + val conditionValue = body.conditions.getOrElse(condition, condition) + conditionHandlerMap.put(conditionValue, handlerExec) + }) + + handlers += handlerExec + }) + new CompoundBodyExec( - body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, session) + body.collection. + map(st => transformTreeIntoExecutable(st)) ++ dropVariables, + handlers.toSeq, conditionHandlerMap, session) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, sparkStatement.origin, isInternal = false) + case handler: ErrorHandler => + val handlerBodyExec = transformTreeIntoExecutable(handler.body). 
+ asInstanceOf[CompoundBodyExec] + new ErrorHandlerExec(handler.conditions, handlerBodyExec, handler.handlerType) } def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { From 3c4c9de54be3ebf9ea3dc9b9356f9daca3dd7a42 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Jul 2024 14:13:39 +0200 Subject: [PATCH 39/99] Add default values in constructor --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../sql/catalyst/parser/SqlScriptingLogicalOperators.scala | 6 +++--- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 6 +++--- .../spark/sql/scripting/SqlScriptingInterpreter.scala | 5 ++++- .../sql/scripting/SqlScriptingExecutionNodeSuite.scala | 3 ++- .../spark/sql/scripting/SqlScriptingInterpreterSuite.scala | 2 +- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e655ac17636b1..e25ef1d4fbbaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -121,7 +121,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { visit(s).asInstanceOf[CompoundBody] }.getOrElse { val logicalPlan = visitSingleStatement(ctx.singleStatement()) - CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), None, Seq()) + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan))) } } @@ -222,7 +222,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val body = Option(ctx.compoundBody()).map(visit).getOrElse { val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan] - CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), None, Seq()) + CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan))) }.asInstanceOf[CompoundBody] val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 9171e61fb6598..b7e1ab6e355e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -60,9 +60,9 @@ case class SingleStatement(parsedPlan: LogicalPlan) */ case class CompoundBody( collection: Seq[CompoundPlanStatement], - label: Option[String], - handlers: Seq[ErrorHandler], - conditions: mutable.HashMap[String, String]) extends CompoundPlanStatement + label: Option[String] = None, + handlers: Seq[ErrorHandler] = Seq.empty, + conditions: mutable.HashMap[String, String] = mutable.HashMap()) extends CompoundPlanStatement /** * Logical operator for an error condition. 
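With the default arguments introduced above, parser call sites that only carry statements no longer need to spell out the empty label, handler and condition arguments. Equivalent forms, using the types from this series:

    CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)))
    // ...is now shorthand for...
    CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), None, Seq.empty, mutable.HashMap())
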
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index c9224da69393c..ae930cb03c4bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -193,9 +193,9 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState */ class CompoundBodyExec( statements: Seq[CompoundStatementExec], - handlers: Seq[ErrorHandlerExec], - conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec], - session: SparkSession) + handlers: Seq[ErrorHandlerExec] = Seq.empty, + conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap(), + session: SparkSession = null) extends CompoundNestedStatementIteratorExec(statements) { private def getHandler(condition: String): Option[ErrorHandlerExec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index d6468ad9e337c..e82e4807eae4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorHandler, SingleStatement} +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorCondition, ErrorHandler, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -106,6 +106,9 @@ case class SqlScriptingInterpreter(session: SparkSession) { val handlerBodyExec = transformTreeIntoExecutable(handler.body). 
asInstanceOf[CompoundBodyExec] new ErrorHandlerExec(handler.conditions, handlerBodyExec, handler.handlerType) + case condition: ErrorCondition => + throw new UnsupportedOperationException( + s"Error condition $condition is not supported in the execution plan.") } def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 003a8061a604b..9d5120674d7cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -36,7 +36,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) extends CompoundNestedStatementIteratorExec(statements) - case class TestBody(statements: Seq[CompoundStatementExec]) extends CompoundBodyExec(statements) + case class TestBody(statements: Seq[CompoundStatementExec]) + extends CompoundBodyExec(statements) case class TestSparkStatementWithPlan(testVal: String) extends SingleStatementExec( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index fc1a630b07523..357d08eca09b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // Helpers private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { - val interpreter = SqlScriptingInterpreter() + val interpreter = SqlScriptingInterpreter(spark) val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) val executionPlan = interpreter.buildExecutionPlan(compoundBody) val result = executionPlan.flatMap { From 874ed32bcc34cdb45e5864b5f959e64607260103 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Jul 2024 14:26:42 +0200 Subject: [PATCH 40/99] Revert empty lines in imports --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e25ef1d4fbbaa..f21db025ec163 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit + import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} + import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} + import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable} import org.apache.spark.internal.{Logging, MDC} import 
org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION @@ -44,7 +47,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, Inte import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition -import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt import org.apache.spark.sql.internal.SQLConf From 5d9c4d193afe1648f2856d3c17e46c8bd66a43dc Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Jul 2024 14:27:30 +0200 Subject: [PATCH 41/99] Fix imports --- .../main/scala/org/apache/spark/sql/SparkSessionExtensions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index fbb37fe57e241..93a00521eef86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.mutable + import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} From 41df74e961dc13deebded7359f3d56d772420a88 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 16 Jul 2024 14:22:02 +0200 Subject: [PATCH 42/99] Add script execution in sql API --- .../org/apache/spark/sql/SparkSession.scala | 61 +++++++++++++++---- .../spark/sql/internal/SessionState.scala | 2 + .../scripting/SqlScriptingInterpreter.scala | 13 +--- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 4e9dcdb0f3af9..bd8db3c188c4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -40,7 +40,9 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.parser.{CompoundBody, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.ExternalCommandRunner @@ 
-632,6 +634,12 @@ class SparkSession private( | Everything else | * ----------------- */ + private def executeScript(compoundBody: CompoundBody): Iterator[Array[Row]] = { + val interpreter = sessionState.sqlScriptingInterpreter + val executionPlan = interpreter.buildExecutionPlan(compoundBody) + interpreter.execute(executionPlan) + } + /** * Executes a SQL query substituting positional parameters by the given arguments, * returning the result as a `DataFrame`. @@ -650,16 +658,31 @@ class SparkSession private( private[sql] def sql(sqlText: String, args: Array[_], tracker: QueryPlanningTracker): DataFrame = withActive { val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { - val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) - if (args.nonEmpty) { - PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq) - } else { - parsedPlan + val parsedPlan = sessionState.sqlParser.parseScript(sqlText) + parsedPlan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _, _, _) if args.nonEmpty => + CompoundBody(List(SingleStatement( + PosParameterizedQuery( + singleStmtPlan.parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)))) + case p => + assert(args.isEmpty, "Named parameters are not supported for batch queries") + p } } - Dataset.ofRows(self, plan, tracker) + + plan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _, _, _) => + Dataset.ofRows(self, singleStmtPlan.parsedPlan, tracker) + case _ => + // execute the plan directly if it is not a single statement + val lastRow = executeScript(plan).foldLeft(Array.empty[Row])((_, next) => next) + val attributes = DataTypeUtils.toAttributes(lastRow.head.schema) + Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow)) + } } + + /** * Executes a SQL query substituting positional parameters by the given arguments, * returning the result as a `DataFrame`. 
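Under the rewritten sql() above, a script with more than one statement is routed through executeScript, the fold keeps only the last collected result, and that result is wrapped back into a DataFrame. A rough usage sketch; the script text and exact statement syntax are assumed for illustration, only the "last result wins" behaviour follows from the fold above:

    val df = spark.sql(
      """BEGIN
        |  SELECT 1;
        |  SELECT 2;
        |END""".stripMargin)
    // df holds the rows of the last statement that produced a result (SELECT 2).
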
@@ -703,14 +726,28 @@ class SparkSession private( tracker: QueryPlanningTracker): DataFrame = withActive { val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { - val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) - if (args.nonEmpty) { - NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr)) - } else { - parsedPlan + val parsedPlan = sessionState.sqlParser.parseScript(sqlText) + parsedPlan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _, _, _) if args.nonEmpty => + CompoundBody(List(SingleStatement( + NameParameterizedQuery( + singleStmtPlan.parsedPlan, args.transform((_, v) => lit(v).expr)))) + ) + case p => + assert(args.isEmpty, "Positional parameters are not supported for batch queries") + p } } - Dataset.ofRows(self, plan, tracker) + + plan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _, _, _) => + Dataset.ofRows(self, singleStmtPlan.parsedPlan, tracker) + case _ => + // execute the plan directly if it is not a single statement + val lastRow = executeScript(plan).foldLeft(Array.empty[Row])((_, next) => next) + val attributes = DataTypeUtils.toAttributes(lastRow.head.schema) + Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index bc6710e6cbdb8..4548da4ed2842 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder import org.apache.spark.sql.execution.datasources.DataSourceManager +import org.apache.spark.sql.scripting.SqlScriptingInterpreter import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{DependencyUtils, Utils} @@ -80,6 +81,7 @@ private[sql] class SessionState( val dataSourceRegistration: DataSourceRegistration, catalogBuilder: () => SessionCatalog, val sqlParser: ParserInterface, + val sqlScriptingInterpreter: SqlScriptingInterpreter, analyzerBuilder: () => Analyzer, optimizerBuilder: () => Optimizer, val planner: SparkPlanner, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index e82e4807eae4a..f876579fe0465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -113,17 +113,8 @@ case class SqlScriptingInterpreter(session: SparkSession) { def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { executionPlan.flatMap { - case statement: SingleStatementExec if !statement.isExecuted => - try { - statement.isExecuted = true - val result = Some(Dataset.ofRows(session, statement.parsedPlan).collect()) - if (statement.collectResult) result else None - } catch { - case e: Exception => - // TODO: check handlers for error conditions - statement.raisedError = true - None - } + case statement: SingleStatementExec if !statement.isExecuted && statement.collectResult => + statement.data case _ => None } } From 
7557dd2058e73d213020b636bef8e125d1037e3f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 16 Jul 2024 16:32:12 +0200 Subject: [PATCH 43/99] Add script execution to sql API --- .../spark/sql/catalyst/parser/AstBuilder.scala | 3 +-- .../parser/SqlScriptingLogicalOperators.scala | 4 ++-- .../scala/org/apache/spark/sql/SparkSession.scala | 4 ++-- .../spark/sql/internal/BaseSessionStateBuilder.scala | 11 +++++++++++ .../sql/scripting/SqlScriptingExecutionNode.scala | 8 ++------ .../sql/scripting/SqlScriptingInterpreter.scala | 12 ++++++------ 6 files changed, 24 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f21db025ec163..7a19439b1c6b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} @@ -58,8 +59,6 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.random.RandomSampler -import scala.collection.mutable - /** * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index b7e1ab6e355e2..a638da801e9cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.parser +import scala.collection.mutable + import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} -import scala.collection.mutable - /** * Trait for all SQL Scripting logical operators that are product of parsing phase. * These operators will be used by the SQL Scripting interpreter to generate execution nodes. 
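The trait doc above describes the hand-off from parsed operators to execution nodes. End to end, the flow added across these patches can be sketched as follows, using only names introduced in this series and written as an orientation aid rather than a prescribed API:

    // Parse the script into a CompoundBody, build the execution node tree,
    // then run it via the interpreter, collecting each statement's rows lazily.
    val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText)
    val interpreter = spark.sessionState.sqlScriptingInterpreter
    val executionPlan = interpreter.buildExecutionPlan(compoundBody)
    val results: Iterator[Array[Row]] = interpreter.execute(executionPlan)
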
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index bd8db3c188c4b..4cb555f79a952 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -677,7 +677,7 @@ class SparkSession private( // execute the plan directly if it is not a single statement val lastRow = executeScript(plan).foldLeft(Array.empty[Row])((_, next) => next) val attributes = DataTypeUtils.toAttributes(lastRow.head.schema) - Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow.toIndexedSeq)) } } @@ -746,7 +746,7 @@ class SparkSession private( // execute the plan directly if it is not a single statement val lastRow = executeScript(plan).foldLeft(Array.empty[Row])((_, next) => next) val attributes = DataTypeUtils.toAttributes(lastRow.head.schema) - Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow.toIndexedSeq)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4660970814e21..23c12fd53ddce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog} import org.apache.spark.sql.execution.streaming.ResolveWriteToStream import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.scripting.SqlScriptingInterpreter import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -146,6 +147,15 @@ abstract class BaseSessionStateBuilder( extensions.buildParser(session, new SparkSqlParser()) } + /** + * Script interpreter that produces execution plan for sql batch procedural language. + * + * Note: this depends on the `conf` field. + */ + protected lazy val scriptingInterpreter: SqlScriptingInterpreter = { + extensions.buildInterpreter(session, SqlScriptingInterpreter(session)) + } + /** * ResourceLoader that is used to load function resources and jars. */ @@ -396,6 +406,7 @@ abstract class BaseSessionStateBuilder( dataSourceRegistration, () => catalog, sqlParser, + scriptingInterpreter, () => analyzer, () => optimizer, planner, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index ae930cb03c4bc..2b29afe171cc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -76,7 +76,8 @@ trait NonLeafStatementExec extends CompoundStatementExec { class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, - override val isInternal: Boolean) + override val isInternal: Boolean, + var collectResult: Boolean = true) // Whether the statement result should be collected in the final result. 
extends LeafStatementExec with WithOrigin { /** @@ -90,11 +91,6 @@ class SingleStatementExec( */ var raisedError = false - /** - * Whether the statement result should be collected in the final result. - */ - var collectResult = false - /** * Data returned after execution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index f876579fe0465..d1b5bc11e0ba1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.scripting -import org.apache.spark.sql.{Dataset, Row, SparkSession} +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorCondition, ErrorHandler, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin -import scala.collection.mutable -import scala.collection.mutable.ListBuffer - /** * SQL scripting interpreter - builds SQL script execution plan. */ @@ -74,7 +74,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { } val dropVariables = variables .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), isInternal = true)) + .map(new SingleStatementExec(_, Origin(), isInternal = true, collectResult = false)) .reverse val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]() @@ -113,7 +113,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { executionPlan.flatMap { - case statement: SingleStatementExec if !statement.isExecuted && statement.collectResult => + case statement: SingleStatementExec if statement.collectResult => statement.data case _ => None } From 02579c90fef1c0d75cfbac812e0225c68e12c51d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 16 Jul 2024 17:15:00 +0200 Subject: [PATCH 44/99] Testing with print --- .../scripting/SqlScriptingExecutionNode.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 2b29afe171cc5..a41dd6d2b478e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.scripting import scala.collection.mutable - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.parser.HandlerType @@ -77,7 +76,7 @@ class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, override val isInternal: Boolean, - var collectResult: Boolean = true) // Whether the statement result should be collected in the final result. 
+ var collectResult: Boolean = true) // Whether the statement result should be collected extends LeafStatementExec with WithOrigin { /** @@ -116,12 +115,19 @@ class SingleStatementExec( def execute(session: SparkSession): Unit = { try { isExecuted = true + print("EXECUTING\n\n\n") val result = Some(Dataset.ofRows(session, parsedPlan).collect()) if (collectResult) data = result } catch { - case e: Exception => + case e: SparkThrowable => // TODO: check handlers for error conditions raisedError = true + errorState = Some(e.getSqlState) + print(s"\n\n\nError raised: ${e.getSqlState}\n\n") + case _: Throwable => + print("\n\n\nError raised: UNKNOWN\n\n") + raisedError = true + errorState = Some("UNKNOWN") } } } @@ -220,6 +226,7 @@ class CompoundBodyExec( curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement.execute(session) if (statement.raisedError) { + print(s"Error raised: ${statement.errorState.get}\n\n") val handler = getHandler(statement.errorState.get).get handler.execute() handler.reset() @@ -253,6 +260,7 @@ class ErrorHandlerExec( def getHandlerType: HandlerType = handlerType def execute(): Unit = { + print("\n\n\nHANDLER\n\n\n") val iterator = body.getTreeIterator while (iterator.hasNext) iterator.next() From f2fb4709a530b6416b0b8ae4dbd8a24bbfccbc37 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 16 Jul 2024 19:00:22 +0200 Subject: [PATCH 45/99] Clean debugging stugg --- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 7 +++---- .../spark/sql/scripting/SqlScriptingInterpreter.scala | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index a41dd6d2b478e..beae3996c3922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.scripting import scala.collection.mutable + import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Row, SparkSession} @@ -115,7 +116,6 @@ class SingleStatementExec( def execute(session: SparkSession): Unit = { try { isExecuted = true - print("EXECUTING\n\n\n") val result = Some(Dataset.ofRows(session, parsedPlan).collect()) if (collectResult) data = result } catch { @@ -123,9 +123,7 @@ class SingleStatementExec( // TODO: check handlers for error conditions raisedError = true errorState = Some(e.getSqlState) - print(s"\n\n\nError raised: ${e.getSqlState}\n\n") case _: Throwable => - print("\n\n\nError raised: UNKNOWN\n\n") raisedError = true errorState = Some("UNKNOWN") } @@ -226,7 +224,6 @@ class CompoundBodyExec( curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement.execute(session) if (statement.raisedError) { - print(s"Error raised: ${statement.errorState.get}\n\n") val handler = getHandler(statement.errorState.get).get handler.execute() handler.reset() @@ -259,6 +256,8 @@ class ErrorHandlerExec( def getHandlerType: HandlerType = handlerType + def getHandlerBody: CompoundBodyExec = body + def execute(): Unit = { print("\n\n\nHANDLER\n\n\n") val iterator = body.getTreeIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index d1b5bc11e0ba1..aafd9ef1c58bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -113,8 +113,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { executionPlan.flatMap { - case statement: SingleStatementExec if statement.collectResult => - statement.data + case statement: SingleStatementExec if statement.collectResult => statement.data case _ => None } } From 6f7189681fe2b4b73a62b6f8199367ac787e4d3e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 17 Jul 2024 17:31:15 +0200 Subject: [PATCH 46/99] Add comment and remove unused parameters from constructors --- .../scripting/SqlScriptingExecutionNode.scala | 18 ++++++++---------- .../scripting/SqlScriptingInterpreter.scala | 15 ++++++--------- .../SqlScriptingExecutionNodeSuite.scala | 2 +- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index beae3996c3922..aed4069676132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -192,8 +192,8 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * Executable nodes for nested statements within the CompoundBody. */ class CompoundBodyExec( + label: Option[String], statements: Seq[CompoundStatementExec], - handlers: Seq[ErrorHandlerExec] = Seq.empty, conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap(), session: SparkSession = null) extends CompoundNestedStatementIteratorExec(statements) { @@ -225,12 +225,10 @@ class CompoundBodyExec( statement.execute(session) if (statement.raisedError) { val handler = getHandler(statement.errorState.get).get - handler.execute() - handler.reset() - + handler.executeAndReset() if (handler.getHandlerType == HandlerType.EXIT) { // TODO: premature exit from the compound ... 
- curr = None + curr = None // throws error because of none } } statement @@ -250,18 +248,18 @@ class CompoundBodyExec( } class ErrorHandlerExec( - conditions: Seq[String], body: CompoundBodyExec, handlerType: HandlerType) extends CompoundStatementExec { def getHandlerType: HandlerType = handlerType - def getHandlerBody: CompoundBodyExec = body + def executeAndReset(): Unit = { + execute() + reset() + } - def execute(): Unit = { - print("\n\n\nHANDLER\n\n\n") + private def execute(): Unit = { val iterator = body.getTreeIterator - while (iterator.hasNext) iterator.next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index aafd9ef1c58bd..5df3114c47437 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorCondition, ErrorHandler, SingleStatement} +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -83,7 +83,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { val handlerBodyExec = transformTreeIntoExecutable(handler.body). asInstanceOf[CompoundBodyExec] val handlerExec = - new ErrorHandlerExec(handler.conditions, handlerBodyExec, handler.handlerType) + new ErrorHandlerExec(handlerBodyExec, handler.handlerType) handler.conditions.foreach(condition => { val conditionValue = body.conditions.getOrElse(condition, condition) @@ -94,21 +94,18 @@ case class SqlScriptingInterpreter(session: SparkSession) { }) new CompoundBodyExec( + body.label, body.collection. map(st => transformTreeIntoExecutable(st)) ++ dropVariables, - handlers.toSeq, conditionHandlerMap, session) + conditionHandlerMap, session) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, sparkStatement.origin, isInternal = false) - case handler: ErrorHandler => - val handlerBodyExec = transformTreeIntoExecutable(handler.body). 
- asInstanceOf[CompoundBodyExec] - new ErrorHandlerExec(handler.conditions, handlerBodyExec, handler.handlerType) - case condition: ErrorCondition => + case _ => throw new UnsupportedOperationException( - s"Error condition $condition is not supported in the execution plan.") + s"Unsupported operation in the execution plan.") } def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 9d5120674d7cc..bbce5345a6990 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -37,7 +37,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { extends CompoundNestedStatementIteratorExec(statements) case class TestBody(statements: Seq[CompoundStatementExec]) - extends CompoundBodyExec(statements) + extends CompoundBodyExec(None, statements) case class TestSparkStatementWithPlan(testVal: String) extends SingleStatementExec( From 2c1ab0d940d1749de731d1aae77498dbaefe4fbd Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 18 Jul 2024 14:10:31 +0200 Subject: [PATCH 47/99] Add execution functions --- .../org/apache/spark/sql/SparkSession.scala | 9 ++- .../spark/sql/SparkSessionExtensions.scala | 13 +++- .../internal/BaseSessionStateBuilder.scala | 11 +++ .../spark/sql/internal/SessionState.scala | 4 +- .../scripting/SqlScriptingExecutionNode.scala | 75 +++++++++++++++++-- .../scripting/SqlScriptingInterpreter.scala | 15 +++- .../SqlScriptingExecutionNodeSuite.scala | 3 + .../SqlScriptingInterpreterSuite.scala | 2 +- 8 files changed, 116 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 4e9dcdb0f3af9..8281bc24610ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,11 +21,9 @@ import java.io.Closeable import java.util.{ServiceLoader, UUID} import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} - import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal - import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, SparkException, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD @@ -40,6 +38,7 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.parser.CompoundBody import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils @@ -632,6 +631,12 @@ class SparkSession private( | Everything else | * ----------------- */ + private def executeScript(compoundBody: CompoundBody): Iterator[Array[Row]] = { + val interpreter = sessionState.sqlScriptingInterpreter + val executionPlan = 
interpreter.buildExecutionPlan(compoundBody) + interpreter.execute(executionPlan) + } + /** * Executes a SQL query substituting positional parameters by the given arguments, * returning the result as a `DataFrame`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 677dba0082575..fbb37fe57e241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import scala.collection.mutable - import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} +import org.apache.spark.sql.scripting.SqlScriptingInterpreter /** * :: Experimental :: @@ -110,6 +110,7 @@ class SparkSessionExtensions { type CheckRuleBuilder = SparkSession => LogicalPlan => Unit type StrategyBuilder = SparkSession => Strategy type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + type InterpreterBuilder = (SparkSession, SqlScriptingInterpreter) => SqlScriptingInterpreter type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder) type ColumnarRuleBuilder = SparkSession => ColumnarRule @@ -330,6 +331,16 @@ class SparkSessionExtensions { } } + private[this] val interpreterBuilders = mutable.Buffer.empty[InterpreterBuilder] + + private[sql] def buildInterpreter( + session: SparkSession, + initial: SqlScriptingInterpreter): SqlScriptingInterpreter = { + interpreterBuilders.foldLeft(initial) { (interpreter, builder) => + builder(session, interpreter) + } + } + /** * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session * and an initial parser. The latter allows for a user to create a partial parser and to delegate diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4660970814e21..23c12fd53ddce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog} import org.apache.spark.sql.execution.streaming.ResolveWriteToStream import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.scripting.SqlScriptingInterpreter import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -146,6 +147,15 @@ abstract class BaseSessionStateBuilder( extensions.buildParser(session, new SparkSqlParser()) } + /** + * Script interpreter that produces execution plan for sql batch procedural language. + * + * Note: this depends on the `conf` field. 
+ */ + protected lazy val scriptingInterpreter: SqlScriptingInterpreter = { + extensions.buildInterpreter(session, SqlScriptingInterpreter(session)) + } + /** * ResourceLoader that is used to load function resources and jars. */ @@ -396,6 +406,7 @@ abstract class BaseSessionStateBuilder( dataSourceRegistration, () => catalog, sqlParser, + scriptingInterpreter, () => analyzer, () => optimizer, planner, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index bc6710e6cbdb8..1cf6e972ecb80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -19,10 +19,8 @@ package org.apache.spark.sql.internal import java.io.File import java.net.URI - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path - import org.apache.spark.annotation.Unstable import org.apache.spark.sql._ import org.apache.spark.sql.artifact.ArtifactManager @@ -36,6 +34,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder import org.apache.spark.sql.execution.datasources.DataSourceManager +import org.apache.spark.sql.scripting.SqlScriptingInterpreter import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{DependencyUtils, Utils} @@ -80,6 +79,7 @@ private[sql] class SessionState( val dataSourceRegistration: DataSourceRegistration, catalogBuilder: () => SessionCatalog, val sqlParser: ParserInterface, + val sqlScriptingInterpreter: SqlScriptingInterpreter, analyzerBuilder: () => Analyzer, optimizerBuilder: () => Optimizer, val planner: SparkPlanner, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 51e9304297f4d..5505cab91e535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -42,7 +43,14 @@ sealed trait CompoundStatementExec extends Logging { /** * Leaf node in the execution tree. */ -trait LeafStatementExec extends CompoundStatementExec +trait LeafStatementExec extends CompoundStatementExec { + + /** + * Execute the statement. + * @param session Spark session. + */ + def execute(session: SparkSession): Unit +} /** * Non-leaf node in the execution tree. It is an iterator over executable child nodes. @@ -71,7 +79,8 @@ trait NonLeafStatementExec extends CompoundStatementExec { class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, - override val isInternal: Boolean) + override val isInternal: Boolean, + val collectResult: Boolean = true) extends LeafStatementExec with WithOrigin { /** @@ -80,6 +89,11 @@ class SingleStatementExec( */ var isExecuted = false + /** + * Data returned after execution. 
+ */ + var data: Option[Array[Row]] = None + /** * Get the SQL query text corresponding to this statement. * @return @@ -91,6 +105,12 @@ class SingleStatementExec( } override def reset(): Unit = isExecuted = false + + def execute(session: SparkSession): Unit = { + isExecuted = true + val result = Some(Dataset.ofRows(session, parsedPlan).collect()) + if (collectResult) data = result + } } /** @@ -102,10 +122,11 @@ class SingleStatementExec( abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) extends NonLeafStatementExec { - private var localIterator = collection.iterator - private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None + protected var localIterator: Iterator[CompoundStatementExec] = collection.iterator + protected var curr: Option[CompoundStatementExec] = + if (localIterator.hasNext) Some(localIterator.next()) else None - private lazy val treeIterator: Iterator[CompoundStatementExec] = + protected lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { val childHasNext = curr match { @@ -152,6 +173,46 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * Executable node for CompoundBody. * @param statements * Executable nodes for nested statements within the CompoundBody. + * @param session + * Spark session. */ -class CompoundBodyExec(statements: Seq[CompoundStatementExec]) - extends CompoundNestedStatementIteratorExec(statements) +class CompoundBodyExec( + statements: Seq[CompoundStatementExec], + session: SparkSession = null) + extends CompoundNestedStatementIteratorExec(statements) { + + override protected lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = { + val childHasNext = curr match { + case Some(body: NonLeafStatementExec) => body.getTreeIterator.hasNext + case Some(_: LeafStatementExec) => true + case None => false + case _ => throw SparkException.internalError( + "Unknown statement type encountered during SQL script interpretation.") + } + localIterator.hasNext || childHasNext + } + + @scala.annotation.tailrec + override def next(): CompoundStatementExec = { + curr match { + case None => throw SparkException.internalError( + "No more elements to iterate through in the current SQL compound statement.") + case Some(statement: SingleStatementExec) => + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + statement.execute(session) // Execute the leaf statement + statement + case Some(body: NonLeafStatementExec) => + if (body.getTreeIterator.hasNext) { + body.getTreeIterator.next() + } else { + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + next() + } + case _ => throw SparkException.internalError( + "Unknown statement type encountered during SQL script interpretation.") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 6453b204e1623..44bb1fcd1ff6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.scripting +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import 
org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} @@ -25,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.Origin /** * SQL scripting interpreter - builds SQL script execution plan. */ -case class SqlScriptingInterpreter() { +case class SqlScriptingInterpreter(session: SparkSession) { /** * Build execution plan and return statements that need to be executed, @@ -70,14 +71,22 @@ case class SqlScriptingInterpreter() { } val dropVariables = variables .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), isInternal = true)) + .map(new SingleStatementExec(_, Origin(), isInternal = true, collectResult = false)) .reverse new CompoundBodyExec( - body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables) + body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, session) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, sparkStatement.origin, isInternal = false) } + + def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { + executionPlan.flatMap { + case statement: SingleStatementExec if statement.collectResult + && !statement.isInternal => statement.data + case _ => None + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 003a8061a604b..68d795b5fb4e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin @@ -31,6 +32,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { // Helpers case class TestLeafStatement(testVal: String) extends LeafStatementExec { override def reset(): Unit = () + + override def execute(session: SparkSession): Unit = () } case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index f9722a2c14b40..57d8c1a69f3c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // Helpers private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { - val interpreter = SqlScriptingInterpreter() + val interpreter = SqlScriptingInterpreter(spark) val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) val executionPlan = interpreter.buildExecutionPlan(compoundBody) val result = executionPlan.flatMap { From a1028bb0710748c55ebbf1aea063a55c3ddf1502 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 18 Jul 2024 
14:28:26 +0200 Subject: [PATCH 48/99] Fix one sql() method --- .../org/apache/spark/sql/SparkSession.scala | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 8281bc24610ad..41e9d6830e43c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.parser.CompoundBody +import org.apache.spark.sql.catalyst.parser.{CompoundBody, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.ExternalCommandRunner @@ -655,14 +656,27 @@ class SparkSession private( private[sql] def sql(sqlText: String, args: Array[_], tracker: QueryPlanningTracker): DataFrame = withActive { val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { - val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) - if (args.nonEmpty) { - PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq) - } else { - parsedPlan + val parsedPlan = sessionState.sqlParser.parseScript(sqlText) + parsedPlan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), label) if args.nonEmpty => + CompoundBody(List(SingleStatement( + PosParameterizedQuery( + singleStmtPlan.parsedPlan, args.map(lit(_).expr).toImmutableArraySeq))), label) + case p => + assert(args.isEmpty, "Named parameters are not supported for batch queries") + p } } - Dataset.ofRows(self, plan, tracker) + + plan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _) => + Dataset.ofRows(self, singleStmtPlan.parsedPlan, tracker) + case _ => + // execute the plan directly if it is not a single statement + val lastRow = executeScript(plan).foldLeft(Array.empty[Row])((_, next) => next) + val attributes = DataTypeUtils.toAttributes(lastRow.head.schema) + Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow.toIndexedSeq)) + } } /** From e3f36381c9de9679ad1a7624a8435d338c69fb45 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 18 Jul 2024 14:30:15 +0200 Subject: [PATCH 49/99] Fix other sql() method --- .../org/apache/spark/sql/SparkSession.scala | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 41e9d6830e43c..3498074685465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -722,14 +722,27 @@ class SparkSession private( tracker: QueryPlanningTracker): DataFrame = withActive { val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { - val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) - if (args.nonEmpty) { - NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr)) - } else { - parsedPlan + 
val parsedPlan = sessionState.sqlParser.parseScript(sqlText) + parsedPlan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), label) if args.nonEmpty => + CompoundBody(List(SingleStatement( + NameParameterizedQuery( + singleStmtPlan.parsedPlan, args.transform((_, v) => lit(v).expr)))), label) + case p => + assert(args.isEmpty, "Positional parameters are not supported for batch queries") + p } } - Dataset.ofRows(self, plan, tracker) + + plan match { + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _) => + Dataset.ofRows(self, singleStmtPlan.parsedPlan, tracker) + case _ => + // execute the plan directly if it is not a single statement + val lastRow = executeScript(plan).foldLeft(Array.empty[Row])((_, next) => next) + val attributes = DataTypeUtils.toAttributes(lastRow.head.schema) + Dataset.ofRows(self, LocalRelation.fromExternalRows(attributes, lastRow.toIndexedSeq)) + } } /** From b749f95223068505c93f4d11ad1677ec25f16b7f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 19 Jul 2024 13:04:25 +0200 Subject: [PATCH 50/99] Change the execution of handler --- .../scripting/SqlScriptingExecutionNode.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index aed4069676132..15642ea402ca8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Row, SparkSession} -import org.apache.spark.sql.catalyst.parser.HandlerType import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -222,14 +221,12 @@ class CompoundBodyExec( "No more elements to iterate through in the current SQL compound statement.") case Some(statement: SingleStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None - statement.execute(session) + if (!statement.isExecuted) statement.execute(session) if (statement.raisedError) { val handler = getHandler(statement.errorState.get).get - handler.executeAndReset() - if (handler.getHandlerType == HandlerType.EXIT) { - // TODO: premature exit from the compound ... 
- curr = None // throws error because of none - } + // Reset handler body to execute it again + handler.getHandlerBody.reset() + return handler.getTreeIterator.next() } statement case Some(body: NonLeafStatementExec) => @@ -249,10 +246,14 @@ class CompoundBodyExec( class ErrorHandlerExec( body: CompoundBodyExec, - handlerType: HandlerType) extends CompoundStatementExec { + handlerType: HandlerType) extends NonLeafStatementExec { + + override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator def getHandlerType: HandlerType = handlerType + def getHandlerBody: CompoundBodyExec = body + def executeAndReset(): Unit = { execute() reset() From b6f4740ea44d022578c1c964696d9b92f09e1f28 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 19 Jul 2024 13:38:45 +0200 Subject: [PATCH 51/99] Address comments --- .../spark/sql/SparkSessionExtensions.scala | 1 + .../internal/BaseSessionStateBuilder.scala | 4 +- .../spark/sql/internal/SessionState.scala | 2 + .../scripting/SqlScriptingExecutionNode.scala | 44 +++---------------- .../scripting/SqlScriptingInterpreter.scala | 2 +- .../SqlScriptingExecutionNodeSuite.scala | 3 +- 6 files changed, 14 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index fbb37fe57e241..93a00521eef86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.mutable + import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 23c12fd53ddce..accc82ad39a66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -148,9 +148,7 @@ abstract class BaseSessionStateBuilder( } /** - * Script interpreter that produces execution plan for sql batch procedural language. - * - * Note: this depends on the `conf` field. + * Script interpreter that produces execution plan and executes SQL scripts. 
*/ protected lazy val scriptingInterpreter: SqlScriptingInterpreter = { extensions.buildInterpreter(session, SqlScriptingInterpreter(session)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 1cf6e972ecb80..4548da4ed2842 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.internal import java.io.File import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.Unstable import org.apache.spark.sql._ import org.apache.spark.sql.artifact.ArtifactManager diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 5505cab91e535..7f8457bf42bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -92,7 +92,7 @@ class SingleStatementExec( /** * Data returned after execution. */ - var data: Option[Array[Row]] = None + var result: Option[Array[Row]] = None /** * Get the SQL query text corresponding to this statement. @@ -108,8 +108,10 @@ class SingleStatementExec( def execute(session: SparkSession): Unit = { isExecuted = true - val result = Some(Dataset.ofRows(session, parsedPlan).collect()) - if (collectResult) data = result + val rows = Some(Dataset.ofRows(session, parsedPlan).collect()) + if (collectResult) { + result = rows + } } } @@ -126,39 +128,7 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState protected var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None - protected lazy val treeIterator: Iterator[CompoundStatementExec] = - new Iterator[CompoundStatementExec] { - override def hasNext: Boolean = { - val childHasNext = curr match { - case Some(body: NonLeafStatementExec) => body.getTreeIterator.hasNext - case Some(_: LeafStatementExec) => true - case None => false - case _ => throw SparkException.internalError( - "Unknown statement type encountered during SQL script interpretation.") - } - localIterator.hasNext || childHasNext - } - - @scala.annotation.tailrec - override def next(): CompoundStatementExec = { - curr match { - case None => throw SparkException.internalError( - "No more elements to iterate through in the current SQL compound statement.") - case Some(statement: LeafStatementExec) => - curr = if (localIterator.hasNext) Some(localIterator.next()) else None - statement - case Some(body: NonLeafStatementExec) => - if (body.getTreeIterator.hasNext) { - body.getTreeIterator.next() - } else { - curr = if (localIterator.hasNext) Some(localIterator.next()) else None - next() - } - case _ => throw SparkException.internalError( - "Unknown statement type encountered during SQL script interpretation.") - } - } - } + protected lazy val treeIterator: Iterator[CompoundStatementExec] = null override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator @@ -178,7 +148,7 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState */ class CompoundBodyExec( statements: Seq[CompoundStatementExec], - session: SparkSession = null) + session: SparkSession) 
extends CompoundNestedStatementIteratorExec(statements) { override protected lazy val treeIterator: Iterator[CompoundStatementExec] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 44bb1fcd1ff6f..7b925f622f9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -85,7 +85,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { executionPlan.flatMap { case statement: SingleStatementExec if statement.collectResult - && !statement.isInternal => statement.data + && !statement.isInternal => statement.result case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 68d795b5fb4e5..4855f932561b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -39,7 +39,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) extends CompoundNestedStatementIteratorExec(statements) - case class TestBody(statements: Seq[CompoundStatementExec]) extends CompoundBodyExec(statements) + case class TestBody(statements: Seq[CompoundStatementExec]) + extends CompoundBodyExec(statements, null) case class TestSparkStatementWithPlan(testVal: String) extends SingleStatementExec( From 084596719eece62bc5766131028935ed9eb757e9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 19 Jul 2024 13:54:31 +0200 Subject: [PATCH 52/99] Update execute method to call buildExecutionPlan --- .../src/main/scala/org/apache/spark/sql/SparkSession.scala | 5 +++-- .../apache/spark/sql/scripting/SqlScriptingInterpreter.scala | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3498074685465..87d590e66a927 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,9 +21,11 @@ import java.io.Closeable import java.util.{ServiceLoader, UUID} import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} + import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal + import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, SparkException, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD @@ -634,8 +636,7 @@ class SparkSession private( private def executeScript(compoundBody: CompoundBody): Iterator[Array[Row]] = { val interpreter = sessionState.sqlScriptingInterpreter - val executionPlan = interpreter.buildExecutionPlan(compoundBody) - interpreter.execute(executionPlan) + interpreter.execute(compoundBody) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 7b925f622f9a1..895c31129e67c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -82,7 +82,8 @@ case class SqlScriptingInterpreter(session: SparkSession) { isInternal = false) } - def execute(executionPlan: Iterator[CompoundStatementExec]): Iterator[Array[Row]] = { + def execute(compoundBody: CompoundBody): Iterator[Array[Row]] = { + val executionPlan = buildExecutionPlan(compoundBody) executionPlan.flatMap { case statement: SingleStatementExec if statement.collectResult && !statement.isInternal => statement.result From e49716e2a8ea2b2641b189c5601afae6687bc5e4 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 19 Jul 2024 14:19:52 +0200 Subject: [PATCH 53/99] Add check for already executed statement --- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 4 +++- .../spark/sql/scripting/SqlScriptingInterpreterSuite.scala | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 7f8457bf42bc7..5926d6ba1e4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -171,7 +171,9 @@ class CompoundBodyExec( "No more elements to iterate through in the current SQL compound statement.") case Some(statement: SingleStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None - statement.execute(session) // Execute the leaf statement + if (!statement.isExecuted) { + statement.execute(session) // Execute the leaf statement + } statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 57d8c1a69f3c4..d1df6241c03a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -35,10 +35,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val executionPlan = interpreter.buildExecutionPlan(compoundBody) val result = executionPlan.flatMap { case statement: SingleStatementExec => - if (statement.isExecuted) { - None - } else { + if (statement.collectResult) { Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) + } else { + None } case _ => None }.toArray From 03ba9e3888da2ee8caf80b2a624573fc6b714d8e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 19 Jul 2024 14:23:32 +0200 Subject: [PATCH 54/99] Add compound body test in interpreter suite --- .../spark/sql/scripting/SqlScriptingInterpreterSuite.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index d1df6241c03a0..4f629f64b6b26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -52,6 +52,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) } + // Tests + test("select 1; select 2;") { + verifySqlScriptResult("BEGIN SELECT 1; SELECT 2; END", Seq(Seq(Row(1)), Seq(Row(2)))) + } + test("multi statement - simple") { withTable("t") { val sqlScript = From 1fd896b9e61351207290dedbd1d3f4da1ec63c8d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 19 Jul 2024 19:20:28 +0200 Subject: [PATCH 55/99] Fix Interpreter test suite --- .../SqlScriptingInterpreterSuite.scala | 75 ++++++++----------- 1 file changed, 31 insertions(+), 44 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 4f629f64b6b26..15c4c1f35a76f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.scripting -import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, Row} -import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSparkSession /** @@ -29,32 +28,25 @@ import org.apache.spark.sql.test.SharedSparkSession */ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // Helpers - private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { + private def verifySqlScriptResult(sqlText: String, expected: Seq[Array[Row]]): Unit = { val interpreter = SqlScriptingInterpreter(spark) val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText) - val executionPlan = interpreter.buildExecutionPlan(compoundBody) - val result = executionPlan.flatMap { - case statement: SingleStatementExec => - if (statement.collectResult) { - Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) - } else { - None - } - case _ => None - }.toArray - + val result = interpreter.execute(compoundBody).toSeq assert(result.length == expected.length) - result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } + result.zip(expected).foreach { + case (actualAnswer, expectedAnswer) => + assert(actualAnswer.sameElements(expectedAnswer)) + } } // Tests test("select 1") { - verifySqlScriptResult("SELECT 1;", Seq(Seq(Row(1)))) + verifySqlScriptResult("SELECT 1;", Seq(Array(Row(1)))) } // Tests test("select 1; select 2;") { - verifySqlScriptResult("BEGIN SELECT 1; SELECT 2; END", Seq(Seq(Row(1)), Seq(Row(2)))) + verifySqlScriptResult("BEGIN SELECT 1; SELECT 2; END", Seq(Array(Row(1)), Array(Row(2)))) } test("multi statement - simple") { @@ -69,10 +61,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // select with filter - Seq(Row(1)) // select + Array.empty[Row], // create table + Array.empty[Row], // insert + Array.empty[Row], // select with filter + Array(Row(1)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -94,10 +86,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin val 
expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert #1 - Seq.empty[Row], // insert #2 - Seq(Row(false)) // select + Array.empty[Row], // create table + Array.empty[Row], // insert #1 + Array.empty[Row], // insert #2 + Array(Row(false)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -113,10 +105,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var + Array.empty[Row], // declare var + Array.empty[Row], // set var + Array(Row(2)), // select ) verifySqlScriptResult(sqlScript, expected) } @@ -141,16 +132,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row], // drop var - Seq.empty[Row], // declare var - Seq(Row(2)), // select - Seq.empty[Row], // drop var - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(4)), // select - Seq.empty[Row] // drop var + Array.empty[Row], // declare var + Array(Row(1)), // select + Array.empty[Row], // declare var + Array(Row(2)), // select + Array.empty[Row], // declare var + Array.empty[Row], // set var + Array(Row(4)), // select ) verifySqlScriptResult(sqlScript, expected) } @@ -193,11 +181,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row], // drop var - explicit - Seq.empty[Row] // drop var - implicit + Array.empty[Row], // declare var + Array.empty[Row], // set var + Array(Row(2)), // select + Array.empty[Row], // drop var - explicit ) verifySqlScriptResult(sqlScript, expected) } From 458d45fc20e859c152ce0c250a7a6ccdee5e10fb Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 12:21:30 +0200 Subject: [PATCH 56/99] Remove comment --- .../scripting/SqlScriptingInterpreterSuite.scala | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 15c4c1f35a76f..d94ef4a98b0c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -44,9 +44,19 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult("SELECT 1;", Seq(Array(Row(1)))) } - // Tests test("select 1; select 2;") { - verifySqlScriptResult("BEGIN SELECT 1; SELECT 2; END", Seq(Array(Row(1)), Array(Row(2)))) + val sqlScript = + """ + |BEGIN + |SELECT 1; + |SELECT 2; + |END + |""".stripMargin + val expected = Seq( + Array(Row(1)), + Array(Row(2)) + ) + verifySqlScriptResult(sqlScript, expected) } test("multi statement - simple") { From c707776c3c0cecc10b2d1416d577968d6b7e50dc Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 12:47:44 +0200 Subject: [PATCH 57/99] Add sqlScriptingEnabled flag --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f50eb9b121589..d25bb731cfc7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3333,6 +3333,14 @@ object SQLConf { .version("2.3.0") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + val SQL_SCRIPTING_ENABLED = + buildConf("spark.sql.scripting.enabled") + .doc("SQL Scripting feature is under development and its use should be done under this " + + "feature flag.") + .version("4.0.0") + .booleanConf + .createWithDefault(Utils.isTesting) + val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString") .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " + "an output as binary. Otherwise, it returns as a string.") @@ -5503,6 +5511,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) + def sqlScriptingEnabled: Boolean = getConf(SQL_SCRIPTING_ENABLED) + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) From 655543bf8f8a87e900b61ae949a1945f3391a69e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 13:02:14 +0200 Subject: [PATCH 58/99] Add error condition --- python/pyspark/errors/error-conditions.json | 5 +++++ .../org/apache/spark/sql/errors/SqlScriptingErrors.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index dd70e814b1ea8..014a81c662e3b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -908,6 +908,11 @@ "Slice with step is not supported." ] }, + "SQL_SCRIPTING_NOT_ENABLED": { + "message": [ + "SQL Scripting is under development and not all features are supported. To enable existing features set <sqlScriptingEnabled> to `true`." + ] + }, "STATE_NOT_EXISTS": { "message": [ "State is either not defined or has already been removed."
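The spark.sql.scripting.enabled flag added above defaults to Utils.isTesting, so outside of test runs it has to be turned on explicitly before a compound statement can be used. The following is a minimal usage sketch, not part of the patch series; it assumes a plain local SparkSession (object and app names are illustrative) and relies on the sql() behavior from PATCH 48/49, where a multi-statement script returns the result of its last statement.

// Minimal sketch (assumed setup, not taken from the patches): enable the flag and run a script.
import org.apache.spark.sql.SparkSession

object SqlScriptingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("sql-scripting-example")
      .getOrCreate()

    // The flag defaults to false outside of testing, so set it explicitly.
    spark.conf.set("spark.sql.scripting.enabled", "true")

    // sql() parses the compound body and returns the last result set of the script.
    val df = spark.sql(
      """
        |BEGIN
        |  DECLARE x = 1;
        |  SET VAR x = x + 1;
        |  SELECT x;
        |END
        |""".stripMargin)
    df.show() // expected: a single row with value 2

    spark.stop()
  }
}

Running the same script with the flag left at false is the situation the SQL_SCRIPTING_NOT_ENABLED condition below is meant to report.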
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index c1ce93e10553b..1cda8f33753cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.errors import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf /** * Object for grouping error messages thrown during parsing/interpreting phase @@ -39,4 +40,11 @@ private[sql] object SqlScriptingErrors extends QueryErrorsBase { messageParameters = Map("endLabel" -> endLabel)) } + def sqlScriptingNotEnabled(): Throwable = { + new SparkException( + errorClass = "SQL_SCRIPTING_NOT_ENABLED", + cause = null, + messageParameters = Map("sqlScriptingEnabled" -> SQLConf.SQL_SCRIPTING_ENABLED.key)) + } + } From 04d6031b86cca27e8233f19dfb8644059f6dfe0a Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 13:10:44 +0200 Subject: [PATCH 59/99] Move error class to UNSUPPORTED_FEATURE class --- common/utils/src/main/resources/error/error-conditions.json | 5 +++++ .../org/apache/spark/sql/errors/SqlScriptingErrors.scala | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 84681ab8c2253..d0e3faf95b5c0 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4674,6 +4674,11 @@ "<variableName> is a VARIABLE and cannot be updated using the SET statement. Use SET VARIABLE <variableName> = ... instead." ] }, + "SQL_SCRIPTING_NOT_ENABLED" : { + "message" : [ + "SQL scripting is under development and not all features are supported. To enable existing features set <sqlScriptingEnabled> to `true`." + ] + }, "STATE_STORE_MULTIPLE_COLUMN_FAMILIES" : { "message" : [ "Creating multiple column families with <stateStoreProvider> is not supported."
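With the condition grouped under UNSUPPORTED_FEATURE, the parse-time guard added in the later patches surfaces to callers as an ordinary SparkException whose message names the flag to set. A rough caller-side sketch follows; it is illustrative only, assumes a SparkSession in scope, and relies on the AstBuilder check introduced in PATCH 60 below.

// Sketch only (assumed setup): what a caller sees when SQL scripting is disabled.
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

def runWithScriptingDisabled(spark: SparkSession): Unit = {
  spark.conf.set("spark.sql.scripting.enabled", "false")
  try {
    // The compound body fails at parse time, before any statement is executed.
    spark.sql("BEGIN SELECT 1; SELECT 2; END")
  } catch {
    case e: SparkException =>
      // Message of UNSUPPORTED_FEATURE.SQL_SCRIPTING_NOT_ENABLED, pointing at
      // spark.sql.scripting.enabled.
      println(e.getMessage)
  }
}

The checkError assertion added to SqlScriptingParserSuite in the following patches exercises this same path under withSQLConf.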
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 1cda8f33753cd..33b5166fdfef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -42,7 +42,7 @@ private[sql] object SqlScriptingErrors extends QueryErrorsBase { def sqlScriptingNotEnabled(): Throwable = { new SparkException( - errorClass = "SQL_SCRIPTING_NOT_ENABLED", + errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_NOT_ENABLED", cause = null, messageParameters = Map("sqlScriptingEnabled" -> SQLConf.SQL_SCRIPTING_ENABLED.key)) } From a32df9c75d1a1f872976e562222b82af621277ef Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 14:21:11 +0200 Subject: [PATCH 60/99] Add check for sqlconf --- .../spark/sql/catalyst/parser/AstBuilder.scala | 3 +++ .../parser/SqlScriptingParserSuite.scala | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f5cf3e717a3ce..964d63679d433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -119,6 +119,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitCompoundOrSingleStatement( ctx: CompoundOrSingleStatementContext): CompoundBody = withOrigin(ctx) { Option(ctx.singleCompoundStatement()).map { s => + if (!SQLConf.get.sqlScriptingEnabled) { + throw SqlScriptingErrors.sqlScriptingNotEnabled() + } visit(s).asInstanceOf[CompoundBody] }.getOrElse { val logicalPlan = visitSingleStatement(ctx.singleStatement()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index d4eb5fd747ac4..d0115bab6eaca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -324,4 +324,20 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .replace("END", "") .trim } + + test("SQL Scripting not enabled") { + val sqlScriptText = + """ + |BEGIN + | DECLARE totalInsCnt = 0; + | SET VAR totalInsCnt = (SELECT x FROM y WHERE id = 1); + |END""".stripMargin + + checkError( + exception = intercept[SparkException] { + parseScript(sqlScriptText) + }, + errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_NOT_ENABLED", + parameters = Map("sqlScriptingEnabled" -> "test")) + } } From e391ee76216a7146ad631fc4245f84eb6b61d351 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 14:45:31 +0200 Subject: [PATCH 61/99] Add test for SQLConf --- .../parser/SqlScriptingParserSuite.scala | 33 +++++++++++-------- .../SqlScriptingInterpreterSuite.scala | 13 +++++--- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index d0115bab6eaca..e1eaf28cc4157 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -18,11 +18,18 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { import CatalystSqlParser._ + protected override def beforeAll(): Unit = { + System.setProperty(IS_TESTING.key, "true") + super.beforeAll() + } + test("single select") { val sqlScriptText = "SELECT 1;" val tree = parseScript(sqlScriptText) @@ -326,18 +333,18 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } test("SQL Scripting not enabled") { - val sqlScriptText = - """ - |BEGIN - | DECLARE totalInsCnt = 0; - | SET VAR totalInsCnt = (SELECT x FROM y WHERE id = 1); - |END""".stripMargin - - checkError( - exception = intercept[SparkException] { - parseScript(sqlScriptText) - }, - errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_NOT_ENABLED", - parameters = Map("sqlScriptingEnabled" -> "test")) + withSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key -> "false") { + val sqlScriptText = + """ + |BEGIN + | SELECT 1; + |END""".stripMargin + checkError( + exception = intercept[SparkException] { + parseScript(sqlScriptText) + }, + errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_NOT_ENABLED", + parameters = Map("sqlScriptingEnabled" -> SQLConf.SQL_SCRIPTING_ENABLED.key)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 03f918e80982b..2bdb9fdbae13a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.scripting +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSparkSession @@ -39,6 +40,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + protected override def beforeAll(): Unit = { + System.setProperty(IS_TESTING.key, "true") + super.beforeAll() + } + // Tests test("select 1") { verifySqlScriptResult("SELECT 1;", Seq(Array(Row(1)))) @@ -132,10 +138,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var + Array.empty[Row], // declare var + Array.empty[Row], // set var + Array(Row(2)), // select ) verifySqlScriptResult(sqlScript, expected) } From ef89698f6e0ba577e241134e1e2f1d287d9ea504 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 16:57:26 +0200 Subject: [PATCH 62/99] Remove nested iterator and fix tests: --- .../parser/SqlScriptingParserSuite.scala | 6 -- .../scripting/SqlScriptingExecutionNode.scala | 66 ++++++++----------- .../scripting/SqlScriptingInterpreter.scala | 6 +- .../SqlScriptingExecutionNodeSuite.scala | 13 ++-- .../SqlScriptingInterpreterSuite.scala | 14 ++-- 5 files changed, 45 insertions(+), 60 deletions(-) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index e1eaf28cc4157..3b1cd02e6257c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -18,18 +18,12 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { import CatalystSqlParser._ - protected override def beforeAll(): Unit = { - System.setProperty(IS_TESTING.key, "true") - super.beforeAll() - } - test("single select") { val sqlScriptText = "SELECT 1;" val tree = parseScript(sqlScriptText) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 5926d6ba1e4f8..9af4bf252287d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -45,11 +45,17 @@ sealed trait CompoundStatementExec extends Logging { */ trait LeafStatementExec extends CompoundStatementExec { - /** - * Execute the statement. - * @param session Spark session. - */ - def execute(session: SparkSession): Unit + /** + * Whether this statement has been executed during the interpretation phase. + * Example: Statements in conditions of If/Else, While, etc. + */ + var isExecuted = false + + /** + * Execute the statement. + * @param session Spark session. + */ + def execute(session: SparkSession): Unit } /** @@ -80,15 +86,9 @@ class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, override val isInternal: Boolean, - val collectResult: Boolean = true) + val shouldCollectResult: Boolean = true) extends LeafStatementExec with WithOrigin { - /** - * Whether this statement has been executed during the interpretation phase. - * Example: Statements in conditions of If/Else, While, etc. - */ - var isExecuted = false - /** * Data returned after execution. */ @@ -109,49 +109,37 @@ class SingleStatementExec( def execute(session: SparkSession): Unit = { isExecuted = true val rows = Some(Dataset.ofRows(session, parsedPlan).collect()) - if (collectResult) { + if (shouldCollectResult) { result = rows } } } /** - * Abstract class for all statements that contain nested statements. - * Implements recursive iterator logic over all child execution nodes. - * @param collection - * Collection of child execution nodes. + * Executable node for CompoundBody. + * @param statements + * Executable nodes for nested statements within the CompoundBody. + * @param session + * Spark session. 
*/ -abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) +class CompoundBodyExec( + statements: Seq[CompoundStatementExec], + session: SparkSession) extends NonLeafStatementExec { - protected var localIterator: Iterator[CompoundStatementExec] = collection.iterator + protected var localIterator: Iterator[CompoundStatementExec] = statements.iterator protected var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None - protected lazy val treeIterator: Iterator[CompoundStatementExec] = null - - override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator override def reset(): Unit = { - collection.foreach(_.reset()) - localIterator = collection.iterator + statements.foreach(_.reset()) + localIterator = statements.iterator curr = if (localIterator.hasNext) Some(localIterator.next()) else None } -} - -/** - * Executable node for CompoundBody. - * @param statements - * Executable nodes for nested statements within the CompoundBody. - * @param session - * Spark session. - */ -class CompoundBodyExec( - statements: Seq[CompoundStatementExec], - session: SparkSession) - extends CompoundNestedStatementIteratorExec(statements) { - override protected lazy val treeIterator: Iterator[CompoundStatementExec] = + protected lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { val childHasNext = curr match { @@ -169,7 +157,7 @@ class CompoundBodyExec( curr match { case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") - case Some(statement: SingleStatementExec) => + case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None if (!statement.isExecuted) { statement.execute(session) // Execute the leaf statement diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 895c31129e67c..406707b78fb6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -37,7 +37,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { * @return * Iterator through collection of statements to be executed. 
*/ - def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { + private def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec].getTreeIterator } @@ -71,7 +71,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { } val dropVariables = variables .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), isInternal = true, collectResult = false)) + .map(new SingleStatementExec(_, Origin(), isInternal = true, shouldCollectResult = false)) .reverse new CompoundBodyExec( body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, session) @@ -85,7 +85,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { def execute(compoundBody: CompoundBody): Iterator[Array[Row]] = { val executionPlan = buildExecutionPlan(compoundBody) executionPlan.flatMap { - case statement: SingleStatementExec if statement.collectResult + case statement: SingleStatementExec if statement.shouldCollectResult && !statement.isInternal => statement.result case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 4855f932561b3..9d447a21e198b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -36,9 +36,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { override def execute(session: SparkSession): Unit = () } - case class TestNestedStatementIterator(statements: Seq[CompoundStatementExec]) - extends CompoundNestedStatementIteratorExec(statements) - case class TestBody(statements: Seq[CompoundStatementExec]) extends CompoundBodyExec(statements, null) @@ -50,7 +47,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { // Tests test("test body - single statement") { - val iter = TestNestedStatementIterator(Seq(TestLeafStatement("one"))).getTreeIterator + val iter = TestBody(Seq(TestLeafStatement("one"))).getTreeIterator val statements = iter.map { case TestLeafStatement(v) => v case _ => fail("Unexpected statement type") @@ -60,7 +57,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { } test("test body - no nesting") { - val iter = TestNestedStatementIterator( + val iter = TestBody( Seq( TestLeafStatement("one"), TestLeafStatement("two"), @@ -75,11 +72,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { } test("test body - nesting") { - val iter = TestNestedStatementIterator( + val iter = TestBody( Seq( - TestNestedStatementIterator(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), + TestBody(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), TestLeafStatement("three"), - TestNestedStatementIterator(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) + TestBody(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) .getTreeIterator val statements = iter.map { case TestLeafStatement(v) => v diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 2bdb9fdbae13a..26a1131fa3819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.scripting -import org.apache.spark.internal.config.Tests.IS_TESTING -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession /** @@ -27,7 +28,7 @@ import org.apache.spark.sql.test.SharedSparkSession * Output from the interpreter (iterator over executable statements) is then checked - statements * are executed and output DataFrames are compared with expected outputs. */ -class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { +class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession { // Helpers private def verifySqlScriptResult(sqlText: String, expected: Seq[Array[Row]]): Unit = { val interpreter = SqlScriptingInterpreter(spark) @@ -41,8 +42,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } protected override def beforeAll(): Unit = { - System.setProperty(IS_TESTING.key, "true") super.beforeAll() + spark.conf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true") + } + + protected override def afterAll(): Unit = { + spark.conf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "false") + super.afterAll() } // Tests From 4905fa370e82b949e6fd8958de61ae7d1c9c0481 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 16:59:47 +0200 Subject: [PATCH 63/99] Change List to Seq in SparkSession sql --- .../src/main/scala/org/apache/spark/sql/SparkSession.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 87d590e66a927..6bd6b402f7de9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -660,7 +660,7 @@ class SparkSession private( val parsedPlan = sessionState.sqlParser.parseScript(sqlText) parsedPlan match { case CompoundBody(Seq(singleStmtPlan: SingleStatement), label) if args.nonEmpty => - CompoundBody(List(SingleStatement( + CompoundBody(Seq(SingleStatement( PosParameterizedQuery( singleStmtPlan.parsedPlan, args.map(lit(_).expr).toImmutableArraySeq))), label) case p => @@ -726,7 +726,7 @@ class SparkSession private( val parsedPlan = sessionState.sqlParser.parseScript(sqlText) parsedPlan match { case CompoundBody(Seq(singleStmtPlan: SingleStatement), label) if args.nonEmpty => - CompoundBody(List(SingleStatement( + CompoundBody(Seq(SingleStatement( NameParameterizedQuery( singleStmtPlan.parsedPlan, args.transform((_, v) => lit(v).expr)))), label) case p => From a4f9ac65113d138615ba7a9e5bca90f7df3ab5aa Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 17:14:04 +0200 Subject: [PATCH 64/99] Update result collection logic --- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 6 +++--- .../spark/sql/scripting/SqlScriptingInterpreter.scala | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 9af4bf252287d..abc5dc132937b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -47,7 +47,7 @@ trait LeafStatementExec extends CompoundStatementExec { /** * Whether this statement has been executed during the interpretation phase. - * Example: Statements in conditions of If/Else, While, etc. + * This is used to avoid re-execution of the same statement. */ var isExecuted = false @@ -85,8 +85,8 @@ trait NonLeafStatementExec extends CompoundStatementExec { class SingleStatementExec( var parsedPlan: LogicalPlan, override val origin: Origin, - override val isInternal: Boolean, - val shouldCollectResult: Boolean = true) + override val isInternal: Boolean = false, + val shouldCollectResult: Boolean = false) extends LeafStatementExec with WithOrigin { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 406707b78fb6e..8be2d921a7678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -71,7 +71,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { } val dropVariables = variables .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), isInternal = true, shouldCollectResult = false)) + .map(new SingleStatementExec(_, Origin(), isInternal = true)) .reverse new CompoundBodyExec( body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, session) @@ -79,14 +79,13 @@ case class SqlScriptingInterpreter(session: SparkSession) { new SingleStatementExec( sparkStatement.parsedPlan, sparkStatement.origin, - isInternal = false) + shouldCollectResult = true) } def execute(compoundBody: CompoundBody): Iterator[Array[Row]] = { val executionPlan = buildExecutionPlan(compoundBody) executionPlan.flatMap { - case statement: SingleStatementExec if statement.shouldCollectResult - && !statement.isInternal => statement.result + case statement: SingleStatementExec if statement.shouldCollectResult => statement.result case _ => None } } From df6e5fd38b4ea27f7f5e03977dfd2678992b60c3 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 18:00:31 +0200 Subject: [PATCH 65/99] Fix access modifiers and add comments --- .../sql/scripting/SqlScriptingExecutionNode.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index abc5dc132937b..58a9e47b2b267 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -81,6 +81,9 @@ trait NonLeafStatementExec extends CompoundStatementExec { * Whether the statement originates from the SQL script or it is created during the * interpretation. Example: DropVariable statements are automatically created at the end of each * compound. + * @param shouldCollectResult + * Whether we should collect result after statement execution. Example: results from conditions + * in if-else or loops should not be collected. 
*/ class SingleStatementExec( var parsedPlan: LogicalPlan, @@ -106,7 +109,7 @@ class SingleStatementExec( override def reset(): Unit = isExecuted = false - def execute(session: SparkSession): Unit = { + override def execute(session: SparkSession): Unit = { isExecuted = true val rows = Some(Dataset.ofRows(session, parsedPlan).collect()) if (shouldCollectResult) { @@ -127,8 +130,8 @@ class CompoundBodyExec( session: SparkSession) extends NonLeafStatementExec { - protected var localIterator: Iterator[CompoundStatementExec] = statements.iterator - protected var curr: Option[CompoundStatementExec] = + private var localIterator: Iterator[CompoundStatementExec] = statements.iterator + private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator @@ -139,7 +142,7 @@ class CompoundBodyExec( curr = if (localIterator.hasNext) Some(localIterator.next()) else None } - protected lazy val treeIterator: Iterator[CompoundStatementExec] = + private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { val childHasNext = curr match { From c758a6202e8e03d948d73436140c79335dadea7a Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 22 Jul 2024 18:54:04 +0200 Subject: [PATCH 66/99] Sync handlers with the latest changes --- .../org/apache/spark/sql/SparkSession.scala | 8 +-- .../scripting/SqlScriptingExecutionNode.scala | 55 +++++++++++++++---- .../scripting/SqlScriptingInterpreter.scala | 10 ++-- .../SqlScriptingExecutionNodeSuite.scala | 4 +- 4 files changed, 58 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 6bd6b402f7de9..f57514d5d6603 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -659,7 +659,7 @@ class SparkSession private( val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { val parsedPlan = sessionState.sqlParser.parseScript(sqlText) parsedPlan match { - case CompoundBody(Seq(singleStmtPlan: SingleStatement), label) if args.nonEmpty => + case CompoundBody(Seq(singleStmtPlan: SingleStatement), label, _, _) if args.nonEmpty => CompoundBody(Seq(SingleStatement( PosParameterizedQuery( singleStmtPlan.parsedPlan, args.map(lit(_).expr).toImmutableArraySeq))), label) @@ -670,7 +670,7 @@ class SparkSession private( } plan match { - case CompoundBody(Seq(singleStmtPlan: SingleStatement), _) => + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _, _, _) => Dataset.ofRows(self, singleStmtPlan.parsedPlan, tracker) case _ => // execute the plan directly if it is not a single statement @@ -725,7 +725,7 @@ class SparkSession private( val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { val parsedPlan = sessionState.sqlParser.parseScript(sqlText) parsedPlan match { - case CompoundBody(Seq(singleStmtPlan: SingleStatement), label) if args.nonEmpty => + case CompoundBody(Seq(singleStmtPlan: SingleStatement), label, _, _) if args.nonEmpty => CompoundBody(Seq(SingleStatement( NameParameterizedQuery( singleStmtPlan.parsedPlan, args.transform((_, v) => lit(v).expr)))), label) @@ -736,7 +736,7 @@ class SparkSession private( } plan match { - case CompoundBody(Seq(singleStmtPlan: SingleStatement), _) => + case CompoundBody(Seq(singleStmtPlan: SingleStatement), _, _, _) => 
Dataset.ofRows(self, singleStmtPlan.parsedPlan, tracker) case _ => // execute the plan directly if it is not a single statement diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 730970d4298b0..8cd646876044a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -59,6 +59,16 @@ trait LeafStatementExec extends CompoundStatementExec { * @param session Spark session. */ def execute(session: SparkSession): Unit + + /** + * Whether an error was raised during the execution of this statement. + */ + var raisedError: Boolean = false + + /** + * Error state of the statement. + */ + var errorState: Option[String] = None } /** @@ -100,16 +110,6 @@ class SingleStatementExec( */ var result: Option[Array[Row]] = None - /** - * Whether an error was raised during the execution of this statement. - */ - var raisedError = false - - /** - * Error state of the statement. - */ - var errorState: Option[String] = None - /** * Get the SQL query text corresponding to this statement. * @return @@ -149,10 +149,16 @@ class SingleStatementExec( * Spark session. */ class CompoundBodyExec( + label: Option[String], statements: Seq[CompoundStatementExec], + conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap(), session: SparkSession) extends NonLeafStatementExec { + private def getHandler(condition: String): Option[ErrorHandlerExec] = { + conditionHandlerMap.get(condition) + } + private var localIterator: Iterator[CompoundStatementExec] = statements.iterator private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None @@ -188,6 +194,12 @@ class CompoundBodyExec( if (!statement.isExecuted) { statement.execute(session) // Execute the leaf statement } + if (statement.raisedError) { + val handler = getHandler(statement.errorState.get).get + // Reset handler body to execute it again + handler.getHandlerBody.reset() + return handler.getTreeIterator.next() + } statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { @@ -202,3 +214,26 @@ class CompoundBodyExec( } } } + +class ErrorHandlerExec( + body: CompoundBodyExec, + handlerType: HandlerType) extends NonLeafStatementExec { + + override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator + + def getHandlerType: HandlerType = handlerType + + def getHandlerBody: CompoundBodyExec = body + + def executeAndReset(): Unit = { + execute() + reset() + } + + private def execute(): Unit = { + val iterator = body.getTreeIterator + while (iterator.hasNext) iterator.next() + } + + override def reset(): Unit = body.reset() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 29c4ce955d8b2..74be9d70ee0ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.scripting -import org.apache.spark.sql.{Row, SparkSession} import scala.collection.mutable import scala.collection.mutable.ListBuffer @@ -96,14 +95,17 @@ case class SqlScriptingInterpreter(session: 
SparkSession) { new CompoundBodyExec( body.label, - body.collection. - map(st => transformTreeIntoExecutable(st)) ++ dropVariables, - conditionHandlerMap, session) + body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, + conditionHandlerMap, + session) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, sparkStatement.origin, shouldCollectResult = true) + case _ => + throw new UnsupportedOperationException( + s"Unsupported operation in the execution plan.") } def execute(compoundBody: CompoundBody): Iterator[Array[Row]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 9d447a21e198b..fc08ec1dfb3c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin +import scala.collection.mutable + /** * Unit tests for execution nodes from SqlScriptingExecutionNode.scala. * Execution nodes are constructed manually and iterated through. @@ -37,7 +39,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite { } case class TestBody(statements: Seq[CompoundStatementExec]) - extends CompoundBodyExec(statements, null) + extends CompoundBodyExec(None, statements, mutable.HashMap(), null) case class TestSparkStatementWithPlan(testVal: String) extends SingleStatementExec( From a6571e3b186d0d481de05b8186b7025d6e5bbb3d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 23 Jul 2024 12:54:33 +0200 Subject: [PATCH 67/99] Add check if handler is defined --- .../sql/scripting/SqlScriptingExecutionNode.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 8cd646876044a..2561e73c3304c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -69,6 +69,11 @@ trait LeafStatementExec extends CompoundStatementExec { * Error state of the statement. */ var errorState: Option[String] = None + + /** + * Error raised during statement execution. 
+ */ + var error: Option[SparkThrowable] = None } /** @@ -134,6 +139,7 @@ class SingleStatementExec( // TODO: check handlers for error conditions raisedError = true errorState = Some(e.getSqlState) + error = Some(e) case _: Throwable => raisedError = true errorState = Some("UNKNOWN") @@ -195,10 +201,11 @@ class CompoundBodyExec( statement.execute(session) // Execute the leaf statement } if (statement.raisedError) { - val handler = getHandler(statement.errorState.get).get - // Reset handler body to execute it again - handler.getHandlerBody.reset() - return handler.getTreeIterator.next() + val handler = getHandler(statement.errorState.get) + if (handler.isDefined) { + handler.get.getHandlerBody.reset() + return handler.get.getTreeIterator.next() + } } statement case Some(body: NonLeafStatementExec) => From b90e1590f6b8df8369c42939b20beb3c3c944e6d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 23 Jul 2024 12:57:35 +0200 Subject: [PATCH 68/99] Remove isExecuted flag because it is not necessary --- .../sql/scripting/SqlScriptingExecutionNode.scala | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 58a9e47b2b267..03997331bff4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -45,12 +45,6 @@ sealed trait CompoundStatementExec extends Logging { */ trait LeafStatementExec extends CompoundStatementExec { - /** - * Whether this statement has been executed during the interpretation phase. - * This is used to avoid re-execution of the same statement. - */ - var isExecuted = false - /** * Execute the statement. * @param session Spark session. 
@@ -107,10 +101,9 @@ class SingleStatementExec( origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) } - override def reset(): Unit = isExecuted = false + override def reset(): Unit = () override def execute(session: SparkSession): Unit = { - isExecuted = true val rows = Some(Dataset.ofRows(session, parsedPlan).collect()) if (shouldCollectResult) { result = rows @@ -162,9 +155,7 @@ class CompoundBodyExec( "No more elements to iterate through in the current SQL compound statement.") case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None - if (!statement.isExecuted) { - statement.execute(session) // Execute the leaf statement - } + statement.execute(session) // Execute the leaf statement statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { From f452cc10003f9af8ee87442e3f069c1284a6256b Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 23 Jul 2024 16:40:33 +0200 Subject: [PATCH 69/99] Add test and separate logic into functions --- .../scripting/SqlScriptingExecutionNode.scala | 58 +++++++++++-------- .../scripting/SqlScriptingInterpreter.scala | 2 + .../SqlScriptingInterpreterSuite.scala | 30 ++++++++++ 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 5bab0196205a3..e454596a0d8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -68,6 +68,8 @@ trait LeafStatementExec extends CompoundStatementExec { * Error raised during statement execution. */ var error: Option[SparkThrowable] = None + + var rethrow: Option[Throwable] = None } /** @@ -119,7 +121,12 @@ class SingleStatementExec( origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) } - override def reset(): Unit = () + override def reset(): Unit = { + raisedError = false + errorState = None + error = None + result = None // Should we do this? 
+ } def execute(session: SparkSession): Unit = { try { @@ -133,9 +140,15 @@ class SingleStatementExec( raisedError = true errorState = Some(e.getSqlState) error = Some(e) - case _: Throwable => + e match { + case throwable: Throwable => + rethrow = Some(throwable) + case _ => + } + case thr: Throwable => raisedError = true errorState = Some("UNKNOWN") + rethrow = Some(thr) } } } @@ -155,7 +168,22 @@ class CompoundBodyExec( extends NonLeafStatementExec { private def getHandler(condition: String): Option[ErrorHandlerExec] = { - conditionHandlerMap.get(condition) + var ret = conditionHandlerMap.get(condition) + if (ret.isEmpty) { + ret = conditionHandlerMap.get("UNKNOWN") + } + ret + } + + private def handleError(statement: LeafStatementExec): CompoundStatementExec = { + if (statement.raisedError) { + getHandler(statement.errorState.get).foreach { handler => + statement.reset() // Clear all flags and result + handler.reset() + return handler.getTreeIterator.next() + } + } + statement } private var localIterator: Iterator[CompoundStatementExec] = statements.iterator @@ -191,17 +219,11 @@ class CompoundBodyExec( case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement.execute(session) // Execute the leaf statement - if (statement.raisedError) { - val handler = getHandler(statement.errorState.get) - if (handler.isDefined) { - handler.get.getHandlerBody.reset() - return handler.get.getTreeIterator.next() - } - } - statement + handleError(statement) case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { - body.getTreeIterator.next() + val statement = body.getTreeIterator.next() + handleError(statement.asInstanceOf[LeafStatementExec]) } else { curr = if (localIterator.hasNext) Some(localIterator.next()) else None next() @@ -221,17 +243,5 @@ class ErrorHandlerExec( def getHandlerType: HandlerType = handlerType - def getHandlerBody: CompoundBodyExec = body - - def executeAndReset(): Unit = { - execute() - reset() - } - - private def execute(): Unit = { - val iterator = body.getTreeIterator - while (iterator.hasNext) iterator.next() - } - override def reset(): Unit = body.reset() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 74be9d70ee0ce..eaba017ceb043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -111,6 +111,8 @@ case class SqlScriptingInterpreter(session: SparkSession) { def execute(compoundBody: CompoundBody): Iterator[Array[Row]] = { val executionPlan = buildExecutionPlan(compoundBody) executionPlan.flatMap { + case statement: SingleStatementExec if statement.raisedError => + throw statement.rethrow.get case statement: SingleStatementExec if statement.shouldCollectResult => statement.result case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 26a1131fa3819..ea7c16a001dbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -227,4 +227,34 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with 
SharedSparkSession ) verifySqlScriptResult(sqlScript, expected) } + + test("handler") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE CONTINUE HANDLER FOR zero_division + | BEGIN + | SET VAR flag = 1; + | END; + | BEGIN + | SELECT 1; + | BEGIN + | SELECT 2; + | SELECT 1/0; + | END; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(1)), // select + Array(Row(2)), // select + Array.empty[Row], // select 1/0 (error) + Array(Row(1)), // select + ) + verifySqlScriptResult(sqlScript, expected) + } } From 574b3da873fc332ef5526b489e6f3b9d6d9f6ffe Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 23 Jul 2024 18:12:52 +0200 Subject: [PATCH 70/99] Add leave statement execution node --- .../sql/catalyst/parser/AstBuilder.scala | 3 +-- .../parser/SqlScriptingLogicalOperators.scala | 6 ++++++ .../scripting/SqlScriptingExecutionNode.scala | 19 +++++++++++++++++++ .../scripting/SqlScriptingInterpreter.scala | 3 +-- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index fa0a96f0262ff..bdd52bde89242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -227,14 +227,13 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = { val conditions = visit(ctx.conditionValueList()).asInstanceOf[Seq[String]] + val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) val body = Option(ctx.compoundBody()).map(visit).getOrElse { val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan] CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan))) }.asInstanceOf[CompoundBody] - val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) - ErrorHandler(conditions, body, handlerType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index a638da801e9cb..ee961589eef81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -88,3 +88,9 @@ case class ErrorHandler( conditions: Seq[String], body: CompoundBody, handlerType: HandlerType) extends CompoundPlanStatement + +/** + * Logical operator for a leave statement. + * @param label Label of the CompoundBody leave statement should exit. 
+ */ +case class BatchLeaveStatement(label: String) extends CompoundPlanStatement \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index e454596a0d8e6..fc53d3f46e694 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -216,6 +216,14 @@ class CompoundBodyExec( curr match { case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") + case Some(leave: LeaveStatementExec) => + if (leave.used) { + curr = None + if (label.getOrElse("").equals(leave.getLabel)) { + leave.execute(session) + } + } + leave case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement.execute(session) // Execute the leaf statement @@ -245,3 +253,14 @@ class ErrorHandlerExec( override def reset(): Unit = body.reset() } + +class LeaveStatementExec(val label: String) extends LeafStatementExec { + + var used: Boolean = false + + def getLabel: String = label + + override def execute(session: SparkSession): Unit = used = true + + override def reset(): Unit = used = false +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index eaba017ceb043..f96efc028aef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.scripting import scala.collection.mutable import scala.collection.mutable.ListBuffer - import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, SingleStatement} +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorHandler, HandlerType, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin From 156a2d789bfd277655f55acdc63359011d2d2332 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 25 Jul 2024 12:39:36 +0200 Subject: [PATCH 71/99] Add handler logic and stopIteration flag --- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../parser/SqlScriptingLogicalOperators.scala | 4 +- .../scripting/SqlScriptingExecutionNode.scala | 80 +++++++++++++----- .../scripting/SqlScriptingInterpreter.scala | 83 ++++++++++++------- .../SqlScriptingInterpreterSuite.scala | 47 +++++++++-- 5 files changed, 155 insertions(+), 65 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdd52bde89242..180044884f5ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -140,7 +140,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val buff = ListBuffer[CompoundPlanStatement]() val handlers = 
ListBuffer[ErrorHandler]() val conditions = mutable.HashMap[String, String]() - val sqlstates = mutable.Set[String]() + val sqlStates = mutable.Set[String]() ctx.compoundStatements.forEach(compoundStatement => { val stmt = visit(compoundStatement).asInstanceOf[CompoundPlanStatement] @@ -149,9 +149,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { case handler: ErrorHandler => handlers += handler case condition: ErrorCondition => assert(!conditions.contains(condition.conditionName)) // Check for duplicate names. - assert(!sqlstates.contains(condition.value)) // Check for duplicate sqlstates. + assert(!sqlStates.contains(condition.value)) // Check for duplicate sqlStates. conditions += condition.conditionName -> condition.value - sqlstates += condition.value + sqlStates += condition.value case s => buff += s } }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index ee961589eef81..eecbab92b3120 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -74,8 +74,8 @@ case class ErrorCondition( value: String) extends CompoundPlanStatement object HandlerType extends Enumeration { - type HandlerType = Value - val EXIT, CONTINUE = Value + type HandlerType = Value + val EXIT, CONTINUE = Value } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index fc53d3f46e694..20410c418f7ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Row, SparkSession} -import org.apache.spark.sql.catalyst.parser.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} @@ -69,6 +68,9 @@ trait LeafStatementExec extends CompoundStatementExec { */ var error: Option[SparkThrowable] = None + /** + * Throwable to rethrow after the statement execution if the error is not handled. + */ var rethrow: Option[Throwable] = None } @@ -136,7 +138,6 @@ class SingleStatementExec( } } catch { case e: SparkThrowable => - // TODO: check handlers for error conditions raisedError = true errorState = Some(e.getSqlState) error = Some(e) @@ -145,10 +146,10 @@ class SingleStatementExec( rethrow = Some(throwable) case _ => } - case thr: Throwable => + case throwable: Throwable => raisedError = true errorState = Some("UNKNOWN") - rethrow = Some(thr) + rethrow = Some(throwable) } } } @@ -175,20 +176,48 @@ class CompoundBodyExec( ret } + /** + * Handle error raised during the execution of the statement. 
+ * @param statement statement that possibly raised the error + * @return pass through the statement + */ private def handleError(statement: LeafStatementExec): CompoundStatementExec = { if (statement.raisedError) { getHandler(statement.errorState.get).foreach { handler => statement.reset() // Clear all flags and result handler.reset() - return handler.getTreeIterator.next() + curr = Some(handler.getHandlerBody) + return handler.getHandlerBody } } statement } + /** + * Check if the leave statement was used, if it is not used stop iterating surrounding + * [[CompoundBodyExec]] and move iterator forward. If the label of the block matches the label of + * the leave statement, mark the leave statement as used. + * @param leave leave statement + * @return pass through the leave statement + */ + private def handleLeave(leave: LeaveStatementExec): LeaveStatementExec = { + if (!leave.used) { + // Hard stop the iteration of the current begin/end block + stopIteration = true + // If label of the block matches the label of the leave statement, + // mark the leave statement as used + if (label.getOrElse("").equals(leave.getLabel)) { + leave.used = true + } + } + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + leave + } + private var localIterator: Iterator[CompoundStatementExec] = statements.iterator private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None + private var stopIteration: Boolean = false // hard stop iteration flag def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator @@ -208,7 +237,7 @@ class CompoundBodyExec( case _ => throw SparkException.internalError( "Unknown statement type encountered during SQL script interpretation.") } - localIterator.hasNext || childHasNext + (localIterator.hasNext || childHasNext) && !stopIteration } @scala.annotation.tailrec @@ -217,21 +246,25 @@ class CompoundBodyExec( case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") case Some(leave: LeaveStatementExec) => - if (leave.used) { - curr = None - if (label.getOrElse("").equals(leave.getLabel)) { - leave.execute(session) - } - } - leave + handleLeave(leave) case Some(statement: LeafStatementExec) => - curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement.execute(session) // Execute the leaf statement - handleError(statement) + if (!statement.raisedError) { + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + } + handleError(statement) // Handle error if raised case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { - val statement = body.getTreeIterator.next() - handleError(statement.asInstanceOf[LeafStatementExec]) + val statement = body.getTreeIterator.next() // Get next statement from the child node + statement match { + case leave: LeaveStatementExec => + handleLeave(leave) + case leafStatement: LeafStatementExec => + // This check is done to handler error in surrounding begin/end block + // if it was not handled in the nested block + handleError(leafStatement) // Handle error if raised + case nonLeafStatement: NonLeafStatementExec => nonLeafStatement + } } else { curr = if (localIterator.hasNext) Some(localIterator.next()) else None next() @@ -243,24 +276,27 @@ class CompoundBodyExec( } } -class ErrorHandlerExec( - body: CompoundBodyExec, - handlerType: HandlerType) extends NonLeafStatementExec { +class ErrorHandlerExec(body: CompoundBodyExec) extends 
NonLeafStatementExec { override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator - def getHandlerType: HandlerType = handlerType + def getHandlerBody: CompoundBodyExec = body override def reset(): Unit = body.reset() } +/** + * Executable node for Leave statement. + * @param label + * Label of the [[CompoundBodyExec]] that should be exited. + */ class LeaveStatementExec(val label: String) extends LeafStatementExec { var used: Boolean = false def getLabel: String = label - override def execute(session: SparkSession): Unit = used = true + override def execute(session: SparkSession): Unit = () override def reset(): Unit = used = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index f96efc028aef9..aeec494af61ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.scripting import scala.collection.mutable import scala.collection.mutable.ListBuffer + import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, ErrorHandler, HandlerType, SingleStatement} +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, HandlerType, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -56,6 +57,55 @@ case class SqlScriptingInterpreter(session: SparkSession) { case _ => None } + private def transformBodyIntoExec( + compoundBody: CompoundBody, + isExitHandler: Boolean = false, + label: String = ""): CompoundBodyExec = { + val variables = compoundBody.collection.flatMap { + case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) + case _ => None + } + val dropVariables = variables + .map(varName => DropVariable(varName, ifExists = true)) + .map(new SingleStatementExec(_, Origin(), isInternal = true)) + .reverse + + val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]() + val handlers = ListBuffer[ErrorHandlerExec]() + compoundBody.handlers.foreach(handler => { + val handlerBodyExec = + transformBodyIntoExec(handler.body, + handler.handlerType == HandlerType.EXIT, + compoundBody.label.get) + val handlerExec = new ErrorHandlerExec(handlerBodyExec) + + handler.conditions.foreach(condition => { + val conditionValue = compoundBody.conditions.getOrElse(condition, condition) + conditionHandlerMap.put(conditionValue, handlerExec) + }) + + handlers += handlerExec + }) + + if (isExitHandler) { + val leave = new LeaveStatementExec(label) + val stmts = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ + dropVariables :+ leave + + return new CompoundBodyExec( + compoundBody.label, + stmts, + conditionHandlerMap, + session) + } + + new CompoundBodyExec( + compoundBody.label, + compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, + conditionHandlerMap, + session) + } + /** * Transform the parsed tree to the executable node. 
* @param node @@ -67,36 +117,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { node match { case body: CompoundBody => // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. - val variables = body.collection.flatMap { - case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) - case _ => None - } - val dropVariables = variables - .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), isInternal = true)) - .reverse - - val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]() - val handlers = ListBuffer[ErrorHandlerExec]() - body.handlers.foreach(handler => { - val handlerBodyExec = transformTreeIntoExecutable(handler.body). - asInstanceOf[CompoundBodyExec] - val handlerExec = - new ErrorHandlerExec(handlerBodyExec, handler.handlerType) - - handler.conditions.foreach(condition => { - val conditionValue = body.conditions.getOrElse(condition, condition) - conditionHandlerMap.put(conditionValue, handlerExec) - }) - - handlers += handlerExec - }) - - new CompoundBodyExec( - body.label, - body.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, - conditionHandlerMap, - session) + transformBodyIntoExec(body) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index ea7c16a001dbc..55de4ed1d3aab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -228,7 +228,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession verifySqlScriptResult(sqlScript, expected) } - test("handler") { + test("handler - continue") { val sqlScript = """ |BEGIN @@ -236,12 +236,13 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession | DECLARE zero_division CONDITION FOR '22012'; | DECLARE CONTINUE HANDLER FOR zero_division | BEGIN + | SELECT flag; | SET VAR flag = 1; | END; | BEGIN - | SELECT 1; + | SELECT 2; | BEGIN - | SELECT 2; + | SELECT 3; | SELECT 1/0; | END; | END; @@ -250,10 +251,42 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |""".stripMargin val expected = Seq( Array.empty[Row], // declare var - Array(Row(1)), // select - Array(Row(2)), // select - Array.empty[Row], // select 1/0 (error) - Array(Row(1)), // select + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(1)), // select + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("handler - exit") { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE flag INT = -1; + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE EXIT HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 3; + | SELECT 1/0; + | SELECT 4; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(1)), // select flag from the outer body ) verifySqlScriptResult(sqlScript, expected) } From 6a37dcc8658da4cbffb76d46833cd30937ac01e0 Mon Sep 17 00:00:00 2001 From: 
Milan Dankovic Date: Fri, 26 Jul 2024 17:15:44 +0200 Subject: [PATCH 72/99] Add tests --- .../scripting/SqlScriptingExecutionNode.scala | 9 +- .../SqlScriptingInterpreterSuite.scala | 126 +++++++++++++++++- 2 files changed, 126 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 20410c418f7ce..42e57daba16e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -181,13 +181,13 @@ class CompoundBodyExec( * @param statement statement that possibly raised the error * @return pass through the statement */ - private def handleError(statement: LeafStatementExec): CompoundStatementExec = { + private def handleError(statement: LeafStatementExec): LeafStatementExec = { if (statement.raisedError) { getHandler(statement.errorState.get).foreach { handler => statement.reset() // Clear all flags and result handler.reset() curr = Some(handler.getHandlerBody) - return handler.getHandlerBody +// return handler.getHandlerBody } } statement @@ -217,6 +217,7 @@ class CompoundBodyExec( private var localIterator: Iterator[CompoundStatementExec] = statements.iterator private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None + private var prev: Option[CompoundStatementExec] = None private var stopIteration: Boolean = false // hard stop iteration flag def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator @@ -260,9 +261,9 @@ class CompoundBodyExec( case leave: LeaveStatementExec => handleLeave(leave) case leafStatement: LeafStatementExec => - // This check is done to handler error in surrounding begin/end block + // This check is done to handle error in surrounding begin/end block // if it was not handled in the nested block - handleError(leafStatement) // Handle error if raised + handleError(leafStatement) case nonLeafStatement: NonLeafStatementExec => nonLeafStatement } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 55de4ed1d3aab..2550c0ce9b818 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -228,7 +228,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession verifySqlScriptResult(sqlScript, expected) } - test("handler - continue") { + test("handler - continue resolve in the same block") { val sqlScript = """ |BEGIN @@ -239,13 +239,87 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession | SELECT flag; | SET VAR flag = 1; | END; + | SELECT 2; + | SELECT 3; + | SELECT 1/0; + | SELECT 4; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(4)), // select + Array(Row(1)), // select + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("handler - continue resolve in outer block") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | DECLARE zero_division CONDITION FOR '22012'; 
+ | DECLARE CONTINUE HANDLER FOR zero_division | BEGIN - | SELECT 2; + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | BEGIN + | SELECT 3; + | BEGIN + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT 7; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(4)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(5)), // select + Array(Row(6)), // select + Array(Row(7)), // select + Array(Row(1)), // select + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("handler - continue resolve in the same block nested") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | SELECT 2; + | BEGIN + | SELECT 3; | BEGIN - | SELECT 3; + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE CONTINUE HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 4; | SELECT 1/0; + | SELECT 5; | END; + | SELECT 6; | END; + | SELECT 7; | SELECT flag; |END |""".stripMargin @@ -253,19 +327,23 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array.empty[Row], // declare var Array(Row(2)), // select Array(Row(3)), // select + Array(Row(4)), // select Array(Row(-1)), // select flag Array.empty[Row], // set flag + Array(Row(5)), // select + Array(Row(6)), // select + Array(Row(7)), // select Array(Row(1)), // select ) verifySqlScriptResult(sqlScript, expected) } - test("handler - exit") { + test("handler - exit resolve in the same block") { val sqlScript = """ |BEGIN + | DECLARE flag INT = -1; | BEGIN - | DECLARE flag INT = -1; | DECLARE zero_division CONDITION FOR '22012'; | DECLARE EXIT HANDLER FOR zero_division | BEGIN @@ -290,4 +368,42 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession ) verifySqlScriptResult(sqlScript, expected) } + + test("handler - exit resolve in outer block") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | BEGIN + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE EXIT HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 3; + | BEGIN + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(4)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + // skip select 5 + // skip select 6 + Array(Row(1)), // select flag from the outer body + ) + verifySqlScriptResult(sqlScript, expected) + } } From fa82a44deca5c7aad00dd33ba05a891266ff7a8c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 26 Jul 2024 17:30:51 +0200 Subject: [PATCH 73/99] Add test --- .../scripting/SqlScriptingExecutionNode.scala | 2 -- .../SqlScriptingInterpreterSuite.scala | 29 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 42e57daba16e5..958987040301d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -187,7 +187,6 @@ class CompoundBodyExec( statement.reset() 
// Clear all flags and result handler.reset() curr = Some(handler.getHandlerBody) -// return handler.getHandlerBody } } statement @@ -217,7 +216,6 @@ class CompoundBodyExec( private var localIterator: Iterator[CompoundStatementExec] = statements.iterator private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None - private var prev: Option[CompoundStatementExec] = None private var stopIteration: Boolean = false // hard stop iteration flag def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 2550c0ce9b818..e74b40d112c07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -406,4 +406,33 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession ) verifySqlScriptResult(sqlScript, expected) } + + test("chained begin end blocks") { + val sqlScript = + """ + |BEGIN + | BEGIN + | SELECT 1; + | SELECT 2; + | END; + | BEGIN + | SELECT 3; + | SELECT 4; + | END; + | BEGIN + | SELECT 5; + | SELECT 6; + | END; + |END + |""".stripMargin + val expected = Seq( + Array(Row(1)), // select + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(4)), // select + Array(Row(5)), // select + Array(Row(6)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } } From 56ca9fdbd209fccac7af6c70871a930ea52de4ec Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 6 Aug 2024 09:31:37 +0200 Subject: [PATCH 74/99] Merge with the latest changes and update tests: --- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 13 ++-- .../scripting/SqlScriptingExecutionNode.scala | 5 +- .../scripting/SqlScriptingInterpreter.scala | 15 ++-- .../SqlScriptingExecutionNodeSuite.scala | 68 +++++++++---------- .../SqlScriptingInterpreterSuite.scala | 6 +- 6 files changed, 52 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5dd850afcfa4c..650a30ae407f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -223,7 +223,9 @@ class AstBuilder extends DataTypeAstBuilder SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan]) }.getOrElse { val stmt = Option(ctx.beginEndCompoundBlock()). - getOrElse(Option(ctx.declareHandler()).getOrElse(ctx.declareCondition())) + getOrElse(Option(ctx.declareHandler()). + getOrElse(Option(ctx.declareCondition()). 
+ getOrElse(ctx.ifElseStatement()))) visit(stmt).asInstanceOf[CompoundPlanStatement] } } @@ -242,6 +244,7 @@ class AstBuilder extends DataTypeAstBuilder buff.toSeq } } + } override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = { IfElseStatement( @@ -255,7 +258,6 @@ class AstBuilder extends DataTypeAstBuilder elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)) ) } - } override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { val conditionName = ctx.multipartIdentifier().getText diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3f417644082c3..0d932e0038266 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,23 +17,20 @@ package org.apache.spark.sql.catalyst.plans -import java.util.IdentityHashMap - -import scala.collection.mutable - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.RuleId -import org.apache.spark.sql.catalyst.rules.UnknownRuleId -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} +import org.apache.spark.sql.catalyst.rules.{RuleId, UnknownRuleId} import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} -import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag, TreePatternBits} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.collection.BitSet +import java.util.IdentityHashMap +import scala.collection.mutable + /** * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class * defines some basic properties of a query plan node, as well as some new transform APIs to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 737d63340c41c..1a367483d0667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -97,9 +97,6 @@ trait NonLeafStatementExec extends CompoundStatementExec { session: SparkSession, statement: LeafStatementExec): Boolean = statement match { case statement: SingleStatementExec => - assert(!statement.isExecuted) - statement.isExecuted = true - // DataFrame evaluates to True if it is single row, single column // of boolean type with value True. val df = Dataset.ofRows(session, statement.parsedPlan) @@ -190,7 +187,7 @@ class SingleStatementExec( * Spark session. 
*/ class CompoundBodyExec( - label: Option[String], + label: Option[String] = None, statements: Seq[CompoundStatementExec], conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap(), session: SparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index f8171829582d6..87864a89339ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -42,10 +42,8 @@ case class SqlScriptingInterpreter(session: SparkSession) { * @return * Iterator through collection of statements to be executed. */ - def buildExecutionPlan( - compound: CompoundBody, - session: SparkSession): Iterator[CompoundStatementExec] = { - transformTreeIntoExecutable(compound, session).asInstanceOf[CompoundBodyExec].getTreeIterator + def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { + transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec].getTreeIterator } /** @@ -115,13 +113,10 @@ case class SqlScriptingInterpreter(session: SparkSession) { * * @param node * Root node of the parsed tree. - * @param session - * Spark session that SQL script is executed within. * @return * Executable statement. */ - private def transformTreeIntoExecutable( - node: CompoundPlanStatement, session: SparkSession): CompoundStatementExec = + private def transformTreeIntoExecutable(node: CompoundPlanStatement): CompoundStatementExec = node match { case body: CompoundBody => // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. @@ -130,9 +125,9 @@ case class SqlScriptingInterpreter(session: SparkSession) { val conditionsExec = conditions.map(condition => new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)) val conditionalBodiesExec = conditionalBodies.map(body => - transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) + transformTreeIntoExecutable(body).asInstanceOf[CompoundBodyExec]) val unconditionalBodiesExec = elseBody.map(body => - transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) + transformTreeIntoExecutable(body).asInstanceOf[CompoundBodyExec]) new IfElseStatementExec( conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) case sparkStatement: SingleStatement => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 8406abe515399..9220e84615dd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -58,13 +58,13 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi // Tests test("test body - single statement") { - val iter = new CompoundBodyExec(Seq(TestLeafStatement("one"))).getTreeIterator + val iter = TestBody(Seq(TestLeafStatement("one"))).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq("one")) } test("test body - no nesting") { - val iter = new CompoundBodyExec( + val iter = TestBody( Seq( TestLeafStatement("one"), TestLeafStatement("two"), @@ -75,26 +75,26 @@ class SqlScriptingExecutionNodeSuite extends 
SparkFunSuite with SharedSparkSessi } test("test body - nesting") { - val iter = new CompoundBodyExec( + val iter = TestBody( Seq( - new CompoundBodyExec(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), + TestBody(Seq(TestLeafStatement("one"), TestLeafStatement("two"))), TestLeafStatement("three"), - new CompoundBodyExec(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) + TestBody(Seq(TestLeafStatement("four"), TestLeafStatement("five"))))) .getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq("one", "two", "three", "four", "five")) } test("if else - enter body of the IF clause") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = true, description = "con1") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + TestBody(Seq(TestLeafStatement("body1"))) ), - elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))), + elseBody = Some(TestBody(Seq(TestLeafStatement("body2")))), session = spark ) )).getTreeIterator @@ -103,15 +103,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else - enter body of the ELSE clause") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + TestBody(Seq(TestLeafStatement("body1"))) ), - elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))), + elseBody = Some(TestBody(Seq(TestLeafStatement("body2")))), session = spark ) )).getTreeIterator @@ -120,17 +120,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - enter body of the IF clause") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = true, description = "con1"), TestIfElseCondition(condVal = false, description = "con2") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))), - new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + TestBody(Seq(TestLeafStatement("body1"))), + TestBody(Seq(TestLeafStatement("body2"))) ), - elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + elseBody = Some(TestBody(Seq(TestLeafStatement("body3")))), session = spark ) )).getTreeIterator @@ -139,17 +139,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - enter body of the ELSE IF clause") { - val iter = new CompoundBodyExec(Seq( + val iter = new TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1"), TestIfElseCondition(condVal = true, description = "con2") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))), - new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + TestBody(Seq(TestLeafStatement("body1"))), + TestBody(Seq(TestLeafStatement("body2"))) ), - elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + elseBody = Some(TestBody(Seq(TestLeafStatement("body3")))), session = spark ) )).getTreeIterator @@ -158,7 +158,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - enter body of the second ELSE IF clause") { - val iter = new CompoundBodyExec(Seq( + val iter 
= TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1"), @@ -166,11 +166,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi TestIfElseCondition(condVal = true, description = "con3") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))), - new CompoundBodyExec(Seq(TestLeafStatement("body2"))), - new CompoundBodyExec(Seq(TestLeafStatement("body3"))) + TestBody(Seq(TestLeafStatement("body1"))), + TestBody(Seq(TestLeafStatement("body2"))), + TestBody(Seq(TestLeafStatement("body3"))) ), - elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body4")))), + elseBody = Some(TestBody(Seq(TestLeafStatement("body4")))), session = spark ) )).getTreeIterator @@ -179,17 +179,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - enter body of the ELSE clause") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1"), TestIfElseCondition(condVal = false, description = "con2") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))), - new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + TestBody(Seq(TestLeafStatement("body1"))), + TestBody(Seq(TestLeafStatement("body2"))) ), - elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + elseBody = Some(TestBody(Seq(TestLeafStatement("body3")))), session = spark ) )).getTreeIterator @@ -198,15 +198,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - without else (successful check)") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1"), TestIfElseCondition(condVal = true, description = "con2") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))), - new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + TestBody(Seq(TestLeafStatement("body1"))), + TestBody(Seq(TestLeafStatement("body2"))) ), elseBody = None, session = spark @@ -217,15 +217,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - without else (unsuccessful checks)") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1"), TestIfElseCondition(condVal = false, description = "con2") ), conditionalBodies = Seq( - new CompoundBodyExec(Seq(TestLeafStatement("body1"))), - new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + TestBody(Seq(TestLeafStatement("body1"))), + TestBody(Seq(TestLeafStatement("body2"))) ), elseBody = None, session = spark diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 2e61dc7bda742..e600927e998cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -445,7 +445,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(42))) + val expected = Seq(Array(Row(42))) 
verifySqlScriptResult(commands, expected) } @@ -462,7 +462,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(42))) + val expected = Seq(Array(Row(42))) verifySqlScriptResult(commands, expected) } @@ -536,7 +536,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |END |""".stripMargin - val expected = Seq(Seq(Row(44))) + val expected = Seq(Array(Row(44))) verifySqlScriptResult(commands, expected) } From 29b4f3d01efef78d2e6927efc7a0e03db4bb5d0b Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 7 Aug 2024 18:13:40 +0200 Subject: [PATCH 75/99] Fix continue handler and add check for duplicate handlers --- .../resources/error/error-conditions.json | 6 ++++ .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../spark/sql/errors/SqlScriptingErrors.scala | 8 +++++ .../scripting/SqlScriptingExecutionNode.scala | 36 +++++++++++++++---- .../scripting/SqlScriptingInterpreter.scala | 10 ++++-- .../SqlScriptingInterpreterSuite.scala | 27 ++++++++++++++ 6 files changed, 79 insertions(+), 10 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 37675976ea7b2..04a1f3de1c327 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1097,6 +1097,12 @@ ], "sqlState" : "42614" }, + "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE" : { + "message" : [ + "Found duplicate handlers for the same SQL state . Please, remove one of them." + ], + "sqlState" : "42710" + }, "DUPLICATE_KEY" : { "message" : [ "Found duplicate keys ." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bd1c0c72e36cb..4bbc6289c438d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -121,7 +121,7 @@ class AstBuilder extends DataTypeAstBuilder ctx: CompoundOrSingleStatementContext): CompoundBody = withOrigin(ctx) { Option(ctx.singleCompoundStatement()).map { s => if (!SQLConf.get.sqlScriptingEnabled) { - throw SqlScriptingErrors.sqlScriptingNotEnabled() + throw SqlScriptingErrors.sqlScriptingNotEnabled(CurrentOrigin.get) } visit(s).asInstanceOf[CompoundBody] }.getOrElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 132bc6e250a9c..5bc1fd97929cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -51,6 +51,14 @@ private[sql] object SqlScriptingErrors { messageParameters = Map("sqlScriptingEnabled" -> SQLConf.SQL_SCRIPTING_ENABLED.key)) } + def duplicateHandlerForSameSqlState(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + def variableDeclarationNotAllowedInScope( origin: Origin, varName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 1a367483d0667..d1d791056dac7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -144,8 +144,13 @@ class SingleStatementExec( * SQL query text. */ def getText: String = { - assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) - origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) +// assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) + try { + origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) + } catch { + case e: Exception => + "DROP VARIABLE" + } } override def reset(): Unit = { @@ -211,6 +216,7 @@ class CompoundBodyExec( getHandler(statement.errorState.get).foreach { handler => statement.reset() // Clear all flags and result handler.reset() + returnHere = curr curr = Some(handler.getHandlerBody) } } @@ -242,6 +248,7 @@ class CompoundBodyExec( private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None private var stopIteration: Boolean = false // hard stop iteration flag + private var returnHere: Option[CompoundStatementExec] = None def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator @@ -261,7 +268,7 @@ class CompoundBodyExec( case _ => throw SparkException.internalError( "Unknown statement type encountered during SQL script interpretation.") } - (localIterator.hasNext || childHasNext) && !stopIteration + (localIterator.hasNext || childHasNext || returnHere.isDefined) && !stopIteration } @scala.annotation.tailrec @@ -273,9 +280,7 @@ class CompoundBodyExec( handleLeave(leave) case Some(statement: LeafStatementExec) => statement.execute(session) // Execute the leaf statement - if (!statement.raisedError) { - curr = if (localIterator.hasNext) Some(localIterator.next()) else None - } + curr = if (localIterator.hasNext) Some(localIterator.next()) else None handleError(statement) // Handle error if raised case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { @@ -290,7 +295,12 @@ class CompoundBodyExec( case nonLeafStatement: NonLeafStatementExec => nonLeafStatement } } else { - curr = if (localIterator.hasNext) Some(localIterator.next()) else None + if (returnHere.isDefined) { + curr = returnHere + returnHere = None + } else { + curr = if (localIterator.hasNext) Some(localIterator.next()) else None + } next() } case _ => throw SparkException.internalError( @@ -325,6 +335,18 @@ class LeaveStatementExec(val label: String) extends LeafStatementExec { override def reset(): Unit = used = false } +/** + * Executable node for Continue statement. + */ +class ContinueStatementExec() extends LeafStatementExec { + + var used: Boolean = false + + override def execute(session: SparkSession): Unit = () + + override def reset(): Unit = used = false +} + /** * Executable node for IfElseStatement. * @param conditions Collection of executable conditions. 
First condition corresponds to IF clause, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 87864a89339ef..1c81ea259a53b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, HandlerType, IfElseStatement, SingleStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} -import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.errors.SqlScriptingErrors /** * SQL scripting interpreter - builds SQL script execution plan. @@ -83,7 +84,12 @@ case class SqlScriptingInterpreter(session: SparkSession) { handler.conditions.foreach(condition => { val conditionValue = compoundBody.conditions.getOrElse(condition, condition) - conditionHandlerMap.put(conditionValue, handlerExec) + conditionHandlerMap.get(conditionValue) match { + case Some(_) => + throw SqlScriptingErrors.duplicateHandlerForSameSqlState( + CurrentOrigin.get, conditionValue) + case None => conditionHandlerMap.put(conditionValue, handlerExec) + } }) handlers += handlerExec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index e600927e998cd..e019974174297 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -228,6 +229,32 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession verifySqlScriptResult(sqlScript, expected) } + test("duplicate handler") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE CONTINUE HANDLER FOR zero_division + | BEGIN + | SET VAR flag = 1; + | END; + | DECLARE CONTINUE HANDLER FOR zero_division + | BEGIN + | SET VAR flag = 2; + | END; + | SELECT 1/0; + | SELECT flag; + |END + |""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + verifySqlScriptResult(sqlScript, Seq.empty) + }, + errorClass = "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE", + parameters = Map("sqlState" -> "22012")) + } + test("handler - continue resolve in the same block") { val sqlScript = """ From 17282c45946516de938c7d0416a6810e0ecc2321 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 8 Aug 2024 17:08:56 +0200 Subject: [PATCH 76/99] Add check for duplicate sqlstate inside condition value list in signle handler --- .../src/main/resources/error/error-conditions.json | 6 ++++++ .../spark/sql/catalyst/parser/AstBuilder.scala | 8 +++++++- .../spark/sql/errors/SqlScriptingErrors.scala | 8 ++++++++ 
.../catalyst/parser/SqlScriptingParserSuite.scala | 14 ++++++++++++++ .../sql/scripting/SqlScriptingExecutionNode.scala | 11 ++++++----- 5 files changed, 41 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 04a1f3de1c327..1f8c55d5e6782 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1127,6 +1127,12 @@ }, "sqlState" : "4274K" }, + "DUPLICATE_SQL_STATE_FOR_SAME_HANDLER" : { + "message" : [ + "Found duplicate SQL state for the same handler. Please, remove one of them." + ], + "sqlState" : "42710" + }, "EMITTING_ROWS_OLDER_THAN_WATERMARK_NOT_ALLOWED" : { "message" : [ "Previous node emitted a row with eventTime= which is older than current_watermark_value=", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4bbc6289c438d..2302e29d3ccc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -245,8 +245,14 @@ class AstBuilder extends DataTypeAstBuilder Option(ctx.SQLEXCEPTION()).map(_ => Seq("SQLEXCEPTION")).getOrElse { Option(ctx.NOT()).map(_ => Seq("NOT FOUND")).getOrElse { val buff = ListBuffer[String]() + val seen = scala.collection.mutable.Set[String]() ctx.conditionValues.forEach { conditionValue => - buff += visit(conditionValue).asInstanceOf[String] + val elem = visit(conditionValue).asInstanceOf[String] + if (seen(elem)) { + throw SqlScriptingErrors.duplicateSqlStateForSameHandler(CurrentOrigin.get, elem) + } + buff += elem + seen += elem } buff.toSeq } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 5bc1fd97929cb..ee710a5c62439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -59,6 +59,14 @@ private[sql] object SqlScriptingErrors { messageParameters = Map("sqlState" -> sqlState)) } + def duplicateSqlStateForSameHandler(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_SQL_STATE_FOR_SAME_HANDLER", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + def variableDeclarationNotAllowedInScope( origin: Origin, varName: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index db563c0eaf248..77a50d521e609 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -396,6 +396,20 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.handlers.head.isInstanceOf[ErrorHandler]) } + test("declare handler duplicate sqlState") { + val sqlScriptText = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR test, test BEGIN SELECT 1; END; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + 
parseScript(sqlScriptText) + }, + errorClass = "DUPLICATE_SQL_STATE_FOR_SAME_HANDLER", + parameters = Map("sqlState" -> "test")) + } + test("SQL Scripting not enabled") { withSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key -> "false") { val sqlScriptText = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index d1d791056dac7..b4371d932ff1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -199,11 +199,12 @@ class CompoundBodyExec( extends NonLeafStatementExec { private def getHandler(condition: String): Option[ErrorHandlerExec] = { - var ret = conditionHandlerMap.get(condition) - if (ret.isEmpty) { - ret = conditionHandlerMap.get("UNKNOWN") - } - ret + conditionHandlerMap.get(condition) + .orElse(conditionHandlerMap.get("NOT FOUND") match { + case Some(handler) if condition.startsWith("02") => Some(handler) + case _ => None + }) + .orElse(conditionHandlerMap.get("UNKNOWN")) } /** From 736e481eb469c4296a11b2de0d6bf12c83a44527 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 9 Aug 2024 11:18:56 +0200 Subject: [PATCH 77/99] Add check for duplicate handler for the same sqlstate --- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 ++- .../sql/scripting/SqlScriptingExecutionNode.scala | 10 +++------- .../sql/scripting/SqlScriptingInterpreterSuite.scala | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2302e29d3ccc5..cb656913527eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -238,7 +238,8 @@ class AstBuilder extends DataTypeAstBuilder } override def visitConditionValue(ctx: ConditionValueContext): String = { - Option(ctx.multipartIdentifier()).map(_.getText).getOrElse(ctx.stringLit().getText) + Option(ctx.multipartIdentifier()).map(_.getText) + .getOrElse(ctx.stringLit().getText).replace("'", "") } override def visitConditionValueList(ctx: ConditionValueListContext): Seq[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index b4371d932ff1b..7206a9c6d4dcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -144,19 +144,15 @@ class SingleStatementExec( * SQL query text. */ def getText: String = { -// assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) - try { - origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) - } catch { - case e: Exception => - "DROP VARIABLE" - } + assert(origin.sqlText.isDefined && origin.startIndex.isDefined && origin.stopIndex.isDefined) + origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1) } override def reset(): Unit = { raisedError = false errorState = None error = None + rethrow = None result = None // Should we do this? 
} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index e019974174297..1df3f9973185a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -239,7 +239,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession | BEGIN | SET VAR flag = 1; | END; - | DECLARE CONTINUE HANDLER FOR zero_division + | DECLARE CONTINUE HANDLER FOR '22012' | BEGIN | SET VAR flag = 2; | END; From 3b6dca083ee7c122de4c3a87fca73dba19245ae8 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 9 Aug 2024 14:01:05 +0200 Subject: [PATCH 78/99] Add catch all handler --- .../scripting/SqlScriptingExecutionNode.scala | 4 +-- .../scripting/SqlScriptingInterpreter.scala | 16 ++++++++-- .../SqlScriptingInterpreterSuite.scala | 29 ++++++++++++++++++- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 7206a9c6d4dcd..524a0eff29989 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -174,7 +174,7 @@ class SingleStatementExec( } case throwable: Throwable => raisedError = true - errorState = Some("UNKNOWN") + errorState = Some("SQLEXCEPTION") rethrow = Some(throwable) } } @@ -200,7 +200,7 @@ class CompoundBodyExec( case Some(handler) if condition.startsWith("02") => Some(handler) case _ => None }) - .orElse(conditionHandlerMap.get("UNKNOWN")) + .orElse(conditionHandlerMap.get("SQLEXCEPTION")) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 1c81ea259a53b..8c9e5157bf04c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.errors.SqlScriptingErrors + /** * SQL scripting interpreter - builds SQL script execution plan. */ @@ -97,12 +98,12 @@ case class SqlScriptingInterpreter(session: SparkSession) { if (isExitHandler) { val leave = new LeaveStatementExec(label) - val stmts = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ + val statements = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables :+ leave return new CompoundBodyExec( compoundBody.label, - stmts, + statements, conditionHandlerMap, session) } @@ -150,7 +151,16 @@ case class SqlScriptingInterpreter(session: SparkSession) { val executionPlan = buildExecutionPlan(compoundBody) executionPlan.flatMap { case statement: SingleStatementExec if statement.raisedError => - throw statement.rethrow.get + val sqlState = statement.errorState.getOrElse(throw statement.rethrow.get) + + // SQLWARNING and NOT FOUND are not considered as errors. 
+ if (!sqlState.startsWith("01") || !sqlState.startsWith("02")) { + // Throw the error for SQLEXCEPTION. + throw statement.rethrow.get + } + + // Return empty result set for SQLWARNING and NOT FOUND. + None case statement: SingleStatementExec if statement.shouldCollectResult => statement.result case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 1df3f9973185a..74e5de30e3b23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -434,6 +434,33 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession verifySqlScriptResult(sqlScript, expected) } + test("handler - continue resolve by the CATCH ALL handler") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | DECLARE CONTINUE HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT flag; + | SET VAR flag = 1; + | END; + | SELECT 2; + | SELECT 1/0; + | SELECT 3; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(3)), // select + Array(Row(1)), // select + ) + verifySqlScriptResult(sqlScript, expected) + } + test("chained begin end blocks") { val sqlScript = """ From 9675b3dcfb1a2d4695719723d0c47963098409e4 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 9 Aug 2024 16:47:38 +0200 Subject: [PATCH 79/99] Fix tests --- docs/sql-ref-ansi-compliance.md | 2 ++ .../sql/catalyst/parser/SqlScriptingParserSuite.scala | 11 +++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 443bc8409efc9..98785e993ffe5 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -454,6 +454,7 @@ Below is a list of all the keywords in Spark SQL. |CONCATENATE|non-reserved|non-reserved|non-reserved| |CONSTRAINT|reserved|non-reserved|reserved| |CONTAINS|non-reserved|non-reserved|non-reserved| +|CONTINUE|non-reserved|non-reserved|non-reserved| |COST|non-reserved|non-reserved|non-reserved| |CREATE|reserved|non-reserved|reserved| |CROSS|reserved|strict-non-reserved|reserved| @@ -530,6 +531,7 @@ Below is a list of all the keywords in Spark SQL. 
|GRANT|reserved|non-reserved|reserved| |GROUP|reserved|non-reserved|reserved| |GROUPING|non-reserved|non-reserved|reserved| +|HANDLER|non-reserved|non-reserved|non-reserved| |HAVING|reserved|non-reserved|reserved| |HOUR|non-reserved|non-reserved|non-reserved| |HOURS|non-reserved|non-reserved|non-reserved| diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 77a50d521e609..4303fe9bcd591 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -357,21 +357,20 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | DECLARE test CONDITION; |END""".stripMargin val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[ErrorCondition]) - assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("45000")) + assert(tree.conditions.size == 1) + assert(tree.conditions("test").equals("45000")) // Default SQLSTATE } test("declare condition: custom sqlstate") { val sqlScriptText = """ |BEGIN + | SELECT 1; | DECLARE test CONDITION FOR '12000'; |END""".stripMargin val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[ErrorCondition]) - assert(tree.collection.head.asInstanceOf[ErrorCondition].value.equals("12000")) + assert(tree.conditions.size == 1) + assert(tree.conditions("test").equals("12000")) } test("declare handler") { From cbe2f0f1fae9fc5719e104da680d223ab525adf1 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 11:04:06 +0200 Subject: [PATCH 80/99] Fix error throw condition from interpreter --- .../apache/spark/sql/scripting/SqlScriptingInterpreter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 8c9e5157bf04c..ae75375b2f237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -154,7 +154,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { val sqlState = statement.errorState.getOrElse(throw statement.rethrow.get) // SQLWARNING and NOT FOUND are not considered as errors. - if (!sqlState.startsWith("01") || !sqlState.startsWith("02")) { + if (!sqlState.startsWith("01") && !sqlState.startsWith("02")) { // Throw the error for SQLEXCEPTION. 
throw statement.rethrow.get } From 564f584947cc494c43e948af5fd8f857a1fc27e2 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 11:27:45 +0200 Subject: [PATCH 81/99] Fix scalastyle --- .../spark/sql/catalyst/plans/QueryPlan.scala | 3 ++- .../SqlScriptingExecutionNodeSuite.scala | 2 +- .../SqlScriptingInterpreterSuite.scala | 20 +++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0d932e0038266..3c8b2ef2ac41c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -21,14 +21,15 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.rules.{RuleId, UnknownRuleId} -import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag, TreePatternBits} +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.collection.BitSet import java.util.IdentityHashMap + import scala.collection.mutable /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 9220e84615dd1..100a30470c4fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -139,7 +139,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("if else if - enter body of the ELSE IF clause") { - val iter = new TestBody(Seq( + val iter = TestBody(Seq( new IfElseStatementExec( conditions = Seq( TestIfElseCondition(condVal = false, description = "con1"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index e1587a0723959..b350a195c7e94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -131,7 +131,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession val expected = Seq( Array.empty[Row], // declare var Array.empty[Row], // set var - Array(Row(2)), // select + Array(Row(2)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -148,7 +148,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession val expected = Seq( Array.empty[Row], // declare var Array.empty[Row], // set var - Array(Row(2)), // select + Array(Row(2)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -179,7 +179,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(2)), // select Array.empty[Row], // declare var Array.empty[Row], // set var 
- Array(Row(4)), // select + Array(Row(4)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -225,7 +225,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array.empty[Row], // declare var Array.empty[Row], // set var Array(Row(2)), // select - Array.empty[Row], // drop var - explicit + Array.empty[Row] // drop var - explicit ) verifySqlScriptResult(sqlScript, expected) } @@ -281,7 +281,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(-1)), // select flag Array.empty[Row], // set flag Array(Row(4)), // select - Array(Row(1)), // select + Array(Row(1)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -321,7 +321,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(5)), // select Array(Row(6)), // select Array(Row(7)), // select - Array(Row(1)), // select + Array(Row(1)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -361,7 +361,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(5)), // select Array(Row(6)), // select Array(Row(7)), // select - Array(Row(1)), // select + Array(Row(1)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -392,7 +392,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(3)), // select Array(Row(-1)), // select flag Array.empty[Row], // set flag - Array(Row(1)), // select flag from the outer body + Array(Row(1)) // select flag from the outer body ) verifySqlScriptResult(sqlScript, expected) } @@ -430,7 +430,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array.empty[Row], // set flag // skip select 5 // skip select 6 - Array(Row(1)), // select flag from the outer body + Array(Row(1)) // select flag from the outer body ) verifySqlScriptResult(sqlScript, expected) } @@ -457,7 +457,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(-1)), // select flag Array.empty[Row], // set flag Array(Row(3)), // select - Array(Row(1)), // select + Array(Row(1)) // select ) verifySqlScriptResult(sqlScript, expected) } From 5c519c11e3465c4126729236d7f0e439958e1349 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 13:20:17 +0200 Subject: [PATCH 82/99] Add keywords --- .../org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 | 4 ++++ .../hive/thriftserver/ThriftServerWithSparkContextSuite.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 3e6b2e938d409..16f0d2da91fb2 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1517,6 +1517,7 @@ ansiNonReserved | COMPUTE | CONCATENATE | CONTAINS + | CONTINUE | COST | CUBE | CURRENT @@ -1555,6 +1556,7 @@ ansiNonReserved | EXCHANGE | EXCLUDE | EXISTS + | EXIT | EXPLAIN | EXPORT | EXTENDED @@ -1842,6 +1844,7 @@ nonReserved | CONCATENATE | CONSTRAINT | CONTAINS + | CONTINUE | COST | CREATE | CUBE @@ -1890,6 +1893,7 @@ nonReserved | EXCLUDE | EXECUTE | EXISTS + | EXIT | EXPLAIN | EXPORT | EXTENDED diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index a5961b036871c..517dc26ada580 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,CONTINUE,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXIT,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From c8073591f83a630118a9971959cb61fccb9450f9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 13:42:18 +0200 Subject: [PATCH 83/99] Regenerate golden files --- .../apache/spark/sql/catalyst/plans/QueryPlan.scala | 8 ++++---- .../sql-tests/analyzer-results/ansi/literals.sql.out | 2 +- .../sql-tests/analyzer-results/literals.sql.out | 2 +- .../resources/sql-tests/results/ansi/keywords.sql.out | 10 ++++++++++ .../resources/sql-tests/results/ansi/literals.sql.out | 2 +- .../test/resources/sql-tests/results/keywords.sql.out | 11 ++++++++++- .../test/resources/sql-tests/results/literals.sql.out | 2 +- .../scripting/SqlScriptingExecutionNodeSuite.scala | 4 ++-- 8 files changed, 30 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3c8b2ef2ac41c..da6de40901499 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst.plans +import java.util.IdentityHashMap + +import scala.collection.mutable + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ @@ -28,10 +32,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.collection.BitSet -import java.util.IdentityHashMap - -import scala.collection.mutable - /** * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class * defines some basic properties of a query plan node, as well as some new transform APIs to diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out index 570cfb73444e5..738d27f0873e1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out @@ -239,7 +239,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42601", "messageParameters" : { - "error" : "'.'", + "error" : "end of input", "hint" : "" } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out index 570cfb73444e5..738d27f0873e1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out @@ -239,7 +239,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42601", "messageParameters" : { - "error" : "'.'", + "error" : "end of input", "hint" : "" } } diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index e03e0f0e3d638..14141615615e8 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -58,8 +58,10 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false +CONDITION true CONSTRAINT true CONTAINS false +CONTINUE false COST false CREATE true CROSS true @@ -110,6 +112,7 @@ EXCHANGE false EXCLUDE false EXECUTE true EXISTS false +EXIT false EXPLAIN false EXPORT false EXTENDED false @@ -127,6 +130,7 @@ FOR true FOREIGN true FORMAT false FORMATTED false +FOUND true FROM true FULL true FUNCTION false @@ -136,6 +140,7 @@ GLOBAL false GRANT true GROUP true GROUPING false +HANDLER true HAVING true HOUR false HOURS false @@ -287,6 +292,7 @@ SORTED false SOURCE false SPECIFIC false SQL true +SQLEXCEPTION true START false STATISTICS false STORED false @@ -376,6 +382,7 @@ CHECK COLLATE COLLATION COLUMN +CONDITION CONSTRAINT CREATE CROSS @@ -394,10 +401,12 @@ FETCH FILTER FOR FOREIGN +FOUND FROM FULL GRANT GROUP +HANDLER HAVING IN INNER @@ -425,6 +434,7 @@ SELECT SESSION_USER SOME SQL +SQLEXCEPTION TABLE THEN TIME diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index 4e4c70cc333ba..672d7f1567e87 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -269,7 +269,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42601", "messageParameters" : { - "error" : "'.'", + "error" : "end of input", "hint" : "" } } diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index e5a371925b1dc..5ab17c33de9da 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -58,8 +58,10 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false +CONDITION true CONSTRAINT false CONTAINS false +CONTINUE false COST false CREATE false CROSS false @@ -110,6 +112,7 @@ EXCHANGE false EXCLUDE false EXECUTE false EXISTS false +EXIT false EXPLAIN false EXPORT false EXTENDED false @@ -127,6 +130,7 @@ FOR false FOREIGN false FORMAT false FORMATTED false +FOUND true FROM false FULL false FUNCTION false @@ -136,6 +140,7 @@ GLOBAL false GRANT false GROUP false GROUPING false +HANDLER true HAVING false HOUR false HOURS false @@ -287,6 +292,7 @@ SORTED false SOURCE false SPECIFIC false SQL false +SQLEXCEPTION true START false STATISTICS false STORED false @@ -364,4 +370,7 @@ SELECT keyword from SQL_KEYWORDS() WHERE reserved -- !query schema struct -- !query output - +CONDITION +FOUND +HANDLER +SQLEXCEPTION diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 4e4c70cc333ba..672d7f1567e87 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -269,7 +269,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42601", "messageParameters" : { - "error" : "'.'", + "error" : "end of input", "hint" : "" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 100a30470c4fc..155887e95398b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.scripting +import scala.collection.mutable + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} @@ -24,8 +26,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.test.SharedSparkSession -import scala.collection.mutable - /** * Unit tests for execution nodes from SqlScriptingExecutionNode.scala. * Execution nodes are constructed manually and iterated through. 
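Note on the handler-resolution behavior that patches 75 through 80 above converge on: a raised SQLSTATE is first matched against a handler declared for that exact state, then, if the state belongs to class '02', against a NOT FOUND handler, and finally against the SQLEXCEPTION catch-all; unhandled errors are rethrown by the interpreter only when the SQLSTATE falls outside classes '01' and '02'. The following is a minimal standalone Scala sketch of that lookup order, using hypothetical names rather than the actual CompoundBodyExec API:

object HandlerLookupSketch {
  // `handlers` maps a condition value to a handler description. Condition values may be a
  // concrete SQLSTATE ("22012"), the special "NOT FOUND" class, or the "SQLEXCEPTION" catch-all.
  def resolve(handlers: Map[String, String], sqlState: String): Option[String] =
    handlers.get(sqlState)
      .orElse(if (sqlState.startsWith("02")) handlers.get("NOT FOUND") else None)
      .orElse(handlers.get("SQLEXCEPTION"))

  def main(args: Array[String]): Unit = {
    val handlers = Map(
      "22012" -> "divide-by-zero handler",
      "SQLEXCEPTION" -> "catch-all handler")
    println(resolve(handlers, "22012"))                  // exact SQLSTATE match
    println(resolve(handlers, "42703"))                  // falls through to the SQLEXCEPTION catch-all
    println(resolve(handlers - "SQLEXCEPTION", "02000")) // None: no NOT FOUND handler declared
  }
}
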
From 0dddefca978d321970ed048b5dd72ff7e631fbc7 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 14:00:35 +0200 Subject: [PATCH 84/99] Revert QueryPlan to original state --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index da6de40901499..3f417644082c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -24,9 +24,11 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.{RuleId, UnknownRuleId} -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag, TreePatternBits} +import org.apache.spark.sql.catalyst.rules.RuleId +import org.apache.spark.sql.catalyst.rules.UnknownRuleId +import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePatternBits import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} From 5ad3474a9f9c6641a6668ddb0fb6ab878816bbd7 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 14:01:47 +0200 Subject: [PATCH 85/99] Remove logical operator for leave statement --- .../sql/catalyst/parser/SqlScriptingLogicalOperators.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index d483129815f0e..9ad2da7de0ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -89,12 +89,6 @@ case class ErrorHandler( body: CompoundBody, handlerType: HandlerType) extends CompoundPlanStatement -/** - * Logical operator for a leave statement. - * @param label Label of the CompoundBody leave statement should exit. - */ -case class BatchLeaveStatement(label: String) extends CompoundPlanStatement - /** * Logical operator for IF ELSE statement. * @param conditions Collection of conditions. 
First condition corresponds to IF clause, From e8d9506fabc467655f41d40bade1177977532205 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 15:48:11 +0200 Subject: [PATCH 86/99] Address comments v1 --- .../resources/error/error-conditions.json | 6 +++ .../sql/catalyst/parser/AstBuilder.scala | 21 ++++------ .../parser/SqlScriptingLogicalOperators.scala | 2 + .../spark/sql/errors/SqlScriptingErrors.scala | 10 +++++ .../scripting/SqlScriptingExecutionNode.scala | 40 +++++-------------- .../scripting/SqlScriptingInterpreter.scala | 25 +++++++----- 6 files changed, 52 insertions(+), 52 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 97bf2cb0bca97..6385065ffcaeb 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1128,6 +1128,12 @@ ], "sqlState" : "42614" }, + "DUPLICATE_CONDITION_NAME_FOR_DIFFERENT_SQL_STATE" : { + "message" : [ + "Found duplicate condition name for different SQL states. Please, remove one of them." + ], + "sqlState" : "42710" + }, "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE" : { "message" : [ "Found duplicate handlers for the same SQL state . Please, remove one of them." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 50d0387da3712..f3d5fd1a92242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -149,8 +149,10 @@ class AstBuilder extends DataTypeAstBuilder stmt match { case handler: ErrorHandler => handlers += handler case condition: ErrorCondition => - assert(!conditions.contains(condition.conditionName)) // Check for duplicate names. - assert(!sqlStates.contains(condition.value)) // Check for duplicate sqlStates. + if (conditions.contains(condition.conditionName)) { + throw SqlScriptingErrors.duplicateConditionNameForDifferentSqlState( + CurrentOrigin.get, condition.conditionName) + } conditions += condition.conditionName -> condition.value sqlStates += condition.value case s => buff += s @@ -229,11 +231,7 @@ class AstBuilder extends DataTypeAstBuilder .map { s => SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan]) }.getOrElse { - val stmt = Option(ctx.beginEndCompoundBlock()). - getOrElse(Option(ctx.declareHandler()). - getOrElse(Option(ctx.declareCondition()). 
- getOrElse(ctx.ifElseStatement()))) - visit(stmt).asInstanceOf[CompoundPlanStatement] + visitChildren(ctx).asInstanceOf[CompoundPlanStatement] } } @@ -245,15 +243,13 @@ class AstBuilder extends DataTypeAstBuilder override def visitConditionValueList(ctx: ConditionValueListContext): Seq[String] = { Option(ctx.SQLEXCEPTION()).map(_ => Seq("SQLEXCEPTION")).getOrElse { Option(ctx.NOT()).map(_ => Seq("NOT FOUND")).getOrElse { - val buff = ListBuffer[String]() - val seen = scala.collection.mutable.Set[String]() + val buff = scala.collection.mutable.Set[String]() ctx.conditionValues.forEach { conditionValue => val elem = visit(conditionValue).asInstanceOf[String] - if (seen(elem)) { + if (buff(elem)) { throw SqlScriptingErrors.duplicateSqlStateForSameHandler(CurrentOrigin.get, elem) } buff += elem - seen += elem } buff.toSeq } @@ -275,8 +271,7 @@ class AstBuilder extends DataTypeAstBuilder override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { val conditionName = ctx.multipartIdentifier().getText - val conditionValue = Option(ctx.stringLit()).map(_.getText).getOrElse("'45000'"). - replace("'", "") + val conditionValue = Option(ctx.stringLit()).map(_.getText.replace("'", "")).getOrElse("45000") val sqlStateRegex = "^[A-Za-z0-9]{5}$".r assert(sqlStateRegex.findFirstIn(conditionValue).isDefined) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 9ad2da7de0ba1..8da650fec7b85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -57,6 +57,8 @@ case class SingleStatement(parsedPlan: LogicalPlan) * @param label Label set to CompoundBody by user or UUID otherwise. * It can be None in case when CompoundBody is not part of BeginEndCompoundBlock * for example when CompoundBody is inside loop or conditional block. + * @param handlers Collection of handlers defined in the compound body. + * @param conditions Map of Condition Name - Sql State values declared in the compound body. 
*/ case class CompoundBody( collection: Seq[CompoundPlanStatement], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index b5602aa19c47d..a9e5638144f99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -68,6 +68,16 @@ private[sql] object SqlScriptingErrors { messageParameters = Map("sqlState" -> sqlState)) } + def duplicateConditionNameForDifferentSqlState( + origin: Origin, + conditionName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_CONDITION_NAME_FOR_DIFFERENT_SQL_STATE", + cause = null, + messageParameters = Map("conditionName" -> conditionName)) + } + def variableDeclarationNotAllowedInScope( origin: Origin, varName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 17d4376fd8b9f..00e3fb15248da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -49,15 +49,7 @@ sealed trait CompoundStatementExec extends Logging { */ trait LeafStatementExec extends CompoundStatementExec { - /** - * Execute the statement. - * @param session Spark session. - */ - def execute(session: SparkSession): Unit - - /** - * Whether an error was raised during the execution of this statement. - */ + /** Whether an error was raised during the execution of this statement. */ var raisedError: Boolean = false /** @@ -65,15 +57,17 @@ trait LeafStatementExec extends CompoundStatementExec { */ var errorState: Option[String] = None - /** - * Error raised during statement execution. - */ + /** Error raised during statement execution. */ var error: Option[SparkThrowable] = None + /** Throwable to rethrow after the statement execution if the error is not handled. */ + var rethrow: Option[Throwable] = None + /** - * Throwable to rethrow after the statement execution if the error is not handled. + * Execute the statement. + * @param session Spark session. */ - var rethrow: Option[Throwable] = None + def execute(session: SparkSession): Unit } /** @@ -141,9 +135,7 @@ class SingleStatementExec( val shouldCollectResult: Boolean = false) extends LeafStatementExec with WithOrigin { - /** - * Data returned after execution. - */ + /** Data returned after execution. */ var result: Option[Array[Row]] = None /** @@ -164,7 +156,7 @@ class SingleStatementExec( result = None // Should we do this? } - def execute(session: SparkSession): Unit = { + override def execute(session: SparkSession): Unit = { try { val rows = Some(Dataset.ofRows(session, parsedPlan).collect()) if (shouldCollectResult) { @@ -340,18 +332,6 @@ class LeaveStatementExec(val label: String) extends LeafStatementExec { override def reset(): Unit = used = false } -/** - * Executable node for Continue statement. - */ -class ContinueStatementExec() extends LeafStatementExec { - - var used: Boolean = false - - override def execute(session: SparkSession): Unit = () - - override def reset(): Unit = used = false -} - /** * Executable node for IfElseStatement. * @param conditions Collection of executable conditions. 
First condition corresponds to IF clause, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index ae75375b2f237..fc746bd41992a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.scripting import scala.collection.mutable -import scala.collection.mutable.ListBuffer import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier @@ -39,12 +38,10 @@ case class SqlScriptingInterpreter(session: SparkSession) { * * @param compound * CompoundBody for which to build the plan. - * @param session - * Spark session that SQL script is executed within. * @return * Iterator through collection of statements to be executed. */ - def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { + private def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = { transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec].getTreeIterator } @@ -61,10 +58,22 @@ case class SqlScriptingInterpreter(session: SparkSession) { case _ => None } + /** + * Fetch the name of the Create Variable plan. + * @param compoundBody + * CompoundBody to be transformed into CompoundBodyExec. + * @param isExitHandler + * Flag to indicate if the body is an exit handler body to add leave statement at the end. + * @param exitHandlerLabel + * If body is an exit handler body, this is the label of surrounding CompoundBody + * that should be exited. + * @return + * Executable version of the CompoundBody . + */ private def transformBodyIntoExec( compoundBody: CompoundBody, isExitHandler: Boolean = false, - label: String = ""): CompoundBodyExec = { + exitHandlerLabel: String = ""): CompoundBodyExec = { val variables = compoundBody.collection.flatMap { case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) case _ => None @@ -75,7 +84,6 @@ case class SqlScriptingInterpreter(session: SparkSession) { .reverse val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]() - val handlers = ListBuffer[ErrorHandlerExec]() compoundBody.handlers.foreach(handler => { val handlerBodyExec = transformBodyIntoExec(handler.body, @@ -92,12 +100,11 @@ case class SqlScriptingInterpreter(session: SparkSession) { case None => conditionHandlerMap.put(conditionValue, handlerExec) } }) - - handlers += handlerExec }) if (isExitHandler) { - val leave = new LeaveStatementExec(label) + // Create leave statement to exit the surrounding CompoundBody after handler execution. 
+ val leave = new LeaveStatementExec(exitHandlerLabel) val statements = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables :+ leave From dc7f52175f9f069213dcb3b39add4e3a6bc468f3 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 16:23:36 +0200 Subject: [PATCH 87/99] Address comments v2 --- .../scripting/SqlScriptingExecutionNode.scala | 31 ++++++++++++------- .../scripting/SqlScriptingInterpreter.scala | 28 +++++++++-------- .../SqlScriptingExecutionNodeSuite.scala | 2 +- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 00e3fb15248da..25f7fdda8e87a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -52,9 +52,7 @@ trait LeafStatementExec extends CompoundStatementExec { /** Whether an error was raised during the execution of this statement. */ var raisedError: Boolean = false - /** - * Error state of the statement. - */ + /** Error state of the statement. */ var errorState: Option[String] = None /** Error raised during statement execution. */ @@ -68,6 +66,13 @@ trait LeafStatementExec extends CompoundStatementExec { * @param session Spark session. */ def execute(session: SparkSession): Unit + + override def reset(): Unit = { + raisedError = false + errorState = None + error = None + rethrow = None + } } /** @@ -149,11 +154,8 @@ class SingleStatementExec( } override def reset(): Unit = { - raisedError = false - errorState = None - error = None - rethrow = None - result = None // Should we do this? + super.reset() + result = None } override def execute(session: SparkSession): Unit = { @@ -188,12 +190,17 @@ class SingleStatementExec( * Spark session. */ class CompoundBodyExec( - label: Option[String] = None, - statements: Seq[CompoundStatementExec], - conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap(), - session: SparkSession) + statements: Seq[CompoundStatementExec], + session: SparkSession, + label: Option[String] = None, + conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap()) extends NonLeafStatementExec { + /** + * Get handler to handle error given by condition. + * @param condition SqlState of the error raised during statement execution. + * @return Corresponding error handler executable node. + */ private def getHandler(condition: String): Option[ErrorHandlerExec] = { conditionHandlerMap.get(condition) .orElse(conditionHandlerMap.get("NOT FOUND") match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index fc746bd41992a..badaa2f59e5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -83,21 +83,27 @@ case class SqlScriptingInterpreter(session: SparkSession) { .map(new SingleStatementExec(_, Origin(), isInternal = true)) .reverse + // Create a map of conditions (SqlStates) to their respective handlers. 
val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]() compoundBody.handlers.foreach(handler => { val handlerBodyExec = transformBodyIntoExec(handler.body, handler.handlerType == HandlerType.EXIT, compoundBody.label.get) + + // Execution node of handler. val handlerExec = new ErrorHandlerExec(handlerBodyExec) + // For each condition handler is defined for, add corresponding key value pair + // to the conditionHandlerMap. handler.conditions.foreach(condition => { + // Condition can either be the key in conditions map or SqlState. val conditionValue = compoundBody.conditions.getOrElse(condition, condition) - conditionHandlerMap.get(conditionValue) match { - case Some(_) => - throw SqlScriptingErrors.duplicateHandlerForSameSqlState( - CurrentOrigin.get, conditionValue) - case None => conditionHandlerMap.put(conditionValue, handlerExec) + if (conditionHandlerMap.contains(conditionValue)) { + throw SqlScriptingErrors.duplicateHandlerForSameSqlState( + CurrentOrigin.get, conditionValue) + } else { + conditionHandlerMap.put(conditionValue, handlerExec) } }) }) @@ -108,18 +114,14 @@ case class SqlScriptingInterpreter(session: SparkSession) { val statements = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables :+ leave - return new CompoundBodyExec( - compoundBody.label, - statements, - conditionHandlerMap, - session) + return new CompoundBodyExec(statements, session, compoundBody.label, conditionHandlerMap) } new CompoundBodyExec( - compoundBody.label, compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++ dropVariables, - conditionHandlerMap, - session) + session, + compoundBody.label, + conditionHandlerMap) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 155887e95398b..63847e57df1c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -40,7 +40,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } case class TestBody(statements: Seq[CompoundStatementExec]) - extends CompoundBodyExec(None, statements, mutable.HashMap(), null) + extends CompoundBodyExec(statements, null, None, mutable.HashMap()) case class TestSparkStatementWithPlan(testVal: String) case class TestIfElseCondition(condVal: Boolean, description: String) From 29f1afb9512916d1848d55cedaf9000757facef2 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 16:35:32 +0200 Subject: [PATCH 88/99] Make ErrorHandlerExec body param public --- .../sql/scripting/SqlScriptingExecutionNode.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 25f7fdda8e87a..297cf8fb74946 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -221,7 +221,7 @@ class CompoundBodyExec( statement.reset() // Clear all flags and result handler.reset() returnHere = curr - curr = Some(handler.getHandlerBody) + curr = Some(handler.body) } } statement @@ -240,7 +240,7 @@ class CompoundBodyExec( stopIteration = 
true // If label of the block matches the label of the leave statement, // mark the leave statement as used - if (label.getOrElse("").equals(leave.getLabel)) { + if (label.getOrElse("").equals(leave.label)) { leave.used = true } } @@ -314,12 +314,10 @@ class CompoundBodyExec( } } -class ErrorHandlerExec(body: CompoundBodyExec) extends NonLeafStatementExec { +class ErrorHandlerExec(val body: CompoundBodyExec) extends NonLeafStatementExec { override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator - def getHandlerBody: CompoundBodyExec = body - override def reset(): Unit = body.reset() } @@ -332,8 +330,6 @@ class LeaveStatementExec(val label: String) extends LeafStatementExec { var used: Boolean = false - def getLabel: String = label - override def execute(session: SparkSession): Unit = () override def reset(): Unit = used = false From 5389e0cb321f01ef02c5c078d2790daa665a09ce Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 16:40:33 +0200 Subject: [PATCH 89/99] Explain get handler logic and fix coding style --- .../sql/scripting/SqlScriptingExecutionNode.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 297cf8fb74946..45bb0c651e110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -203,11 +203,13 @@ class CompoundBodyExec( */ private def getHandler(condition: String): Option[ErrorHandlerExec] = { conditionHandlerMap.get(condition) - .orElse(conditionHandlerMap.get("NOT FOUND") match { - case Some(handler) if condition.startsWith("02") => Some(handler) - case _ => None - }) - .orElse(conditionHandlerMap.get("SQLEXCEPTION")) + .orElse{ + conditionHandlerMap.get("NOT FOUND") match { + // If NOT FOUND handler is defined, use it only for errors with class '02'. + case Some(handler) if condition.startsWith("02") => Some(handler) + case _ => None + }} + .orElse{conditionHandlerMap.get("SQLEXCEPTION")} } /** From 12bdedeae31fa376a40d29544f3cf7cd32700eef Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 17:23:59 +0200 Subject: [PATCH 90/99] Add more comments abou iterator execution --- .../sql/scripting/SqlScriptingExecutionNode.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 45bb0c651e110..695526b6bdcaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -55,9 +55,6 @@ trait LeafStatementExec extends CompoundStatementExec { /** Error state of the statement. */ var errorState: Option[String] = None - /** Error raised during statement execution. */ - var error: Option[SparkThrowable] = None - /** Throwable to rethrow after the statement execution if the error is not handled. 
*/ var rethrow: Option[Throwable] = None @@ -168,7 +165,6 @@ class SingleStatementExec( case e: SparkThrowable => raisedError = true errorState = Some(e.getSqlState) - error = Some(e) e match { case throwable: Throwable => rethrow = Some(throwable) @@ -253,7 +249,12 @@ class CompoundBodyExec( private var localIterator: Iterator[CompoundStatementExec] = statements.iterator private var curr: Option[CompoundStatementExec] = if (localIterator.hasNext) Some(localIterator.next()) else None - private var stopIteration: Boolean = false // hard stop iteration flag + + // Flag to stop the iteration of the current begin/end block. + // It is set to true when non-consumed leave statement is encountered. + private var stopIteration: Boolean = false + + // Statement to return to after handling the error with continue handler. private var returnHere: Option[CompoundStatementExec] = None def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator From 2e0746150fe7c31afa507c51475434f3c5aff64e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 12 Aug 2024 18:09:00 +0200 Subject: [PATCH 91/99] Fix error --- .../apache/spark/sql/scripting/SqlScriptingExecutionNode.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 695526b6bdcaf..2e17c7fdb80f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -67,7 +67,6 @@ trait LeafStatementExec extends CompoundStatementExec { override def reset(): Unit = { raisedError = false errorState = None - error = None rethrow = None } } From c0c2d5d656b04d4fe56128a95d8ce0251c7b5b9e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 13 Aug 2024 10:10:50 +0200 Subject: [PATCH 92/99] Refactor label equality check for leave statement --- .../sql/scripting/SqlScriptingExecutionNode.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 2e17c7fdb80f7..7124100bfe0d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -236,9 +236,13 @@ class CompoundBodyExec( // Hard stop the iteration of the current begin/end block stopIteration = true // If label of the block matches the label of the leave statement, - // mark the leave statement as used - if (label.getOrElse("").equals(leave.label)) { - leave.used = true + // mark the leave statement as used. label can be None in case of a + // CompoundBody inside loop or if/else structure. In such cases, + // loop will have its own label to be matched by leave statement. 
+ if (label.isDefined) { + leave.used = label.get.equals(leave.label) + } else { + leave.used = false } } curr = if (localIterator.hasNext) Some(localIterator.next()) else None From 7598e0fe9aff896f54856985c30f7aff8faa586c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 13 Aug 2024 10:55:45 +0200 Subject: [PATCH 93/99] Add setStatementWithOptionalVarKeyword to the handler grammar and visitor --- .../sql/catalyst/parser/SqlBaseParser.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../SqlScriptingExecutionNodeSuite.scala | 18 +++--- .../SqlScriptingInterpreterSuite.scala | 62 +++++++++---------- 4 files changed, 42 insertions(+), 42 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 8259ec66befbb..f80a04b911c22 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -103,7 +103,7 @@ declareCondition ; declareHandler - : DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN compoundBody END | statement) + : DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN compoundBody END | statement | setStatementWithOptionalVarKeyword) ; beginLabel diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 14a822a5ae162..e6ad1161e1d5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -302,7 +302,7 @@ class AstBuilder extends DataTypeAstBuilder val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) val body = Option(ctx.compoundBody()).map(visit).getOrElse { - val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan] + val logicalPlan = visitChildren(ctx).asInstanceOf[LogicalPlan] CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan))) }.asInstanceOf[CompoundBody] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index dba144859d9ec..b86927f349550 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -268,10 +268,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("while - doesn't enter body") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( TestWhile( condition = TestWhileCondition(condVal = true, reps = 0, description = "con1"), - body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + body = TestBody(Seq(TestLeafStatement("body1"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -279,10 +279,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("while - enters body once") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( TestWhile( condition = TestWhileCondition(condVal = true, reps = 1, description = "con1"), - body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + body = TestBody(Seq(TestLeafStatement("body1"))) ) )).getTreeIterator 
val statements = iter.map(extractStatementValue).toSeq @@ -290,10 +290,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("while - enters body with multiple statements multiple times") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( TestWhile( condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), - body = new CompoundBodyExec(Seq( + body = TestBody(Seq( TestLeafStatement("statement1"), TestLeafStatement("statement2"))) ) @@ -304,13 +304,13 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } test("nested while - 2 times outer 2 times inner") { - val iter = new CompoundBodyExec(Seq( + val iter = TestBody(Seq( TestWhile( condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), - body = new CompoundBodyExec(Seq( + body = TestBody(Seq( TestWhile( condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"), - body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + body = TestBody(Seq(TestLeafStatement("body1"))) )) ) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8e6d2f7f3fda3..e2a5e90a4ed89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -696,14 +696,14 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare i - Seq(Row(0)), // select i - Seq.empty[Row], // set i - Seq(Row(1)), // select i - Seq.empty[Row], // set i - Seq(Row(2)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop var + Array.empty[Row], // declare i + Array(Row(0)), // select i + Array.empty[Row], // set i + Array(Row(1)), // select i + Array.empty[Row], // set i + Array(Row(2)), // select i + Array.empty[Row], // set i + Array.empty[Row] // drop var ) verifySqlScriptResult(commands, expected) } @@ -721,8 +721,8 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare i - Seq.empty[Row] // drop i + Array.empty[Row], // declare i + Array.empty[Row] // drop i ) verifySqlScriptResult(commands, expected) } @@ -745,22 +745,22 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare i - Seq.empty[Row], // declare j - Seq.empty[Row], // set j to 0 - Seq(Row(0, 0)), // select i, j - Seq.empty[Row], // increase j - Seq(Row(0, 1)), // select i, j - Seq.empty[Row], // increase j - Seq.empty[Row], // increase i - Seq.empty[Row], // set j to 0 - Seq(Row(1, 0)), // select i, j - Seq.empty[Row], // increase j - Seq(Row(1, 1)), // select i, j - Seq.empty[Row], // increase j - Seq.empty[Row], // increase i - Seq.empty[Row], // drop j - Seq.empty[Row] // drop i + Array.empty[Row], // declare i + Array.empty[Row], // declare j + Array.empty[Row], // set j to 0 + Array(Row(0, 0)), // select i, j + Array.empty[Row], // increase j + Array(Row(0, 1)), // select i, j + Array.empty[Row], // increase j + Array.empty[Row], // increase i + Array.empty[Row], // set j to 0 + Array(Row(1, 0)), // select i, j + Array.empty[Row], // increase j + Array(Row(1, 1)), // select i, j + Array.empty[Row], 
// increase j + Array.empty[Row], // increase i + Array.empty[Row], // drop j + Array.empty[Row] // drop i ) verifySqlScriptResult(commands, expected) } @@ -779,11 +779,11 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |""".stripMargin val expected = Seq( - Seq.empty[Row], // create table - Seq(Row(42)), // select - Seq.empty[Row], // insert - Seq(Row(42)), // select - Seq.empty[Row] // insert + Array.empty[Row], // create table + Array(Row(42)), // select + Array.empty[Row], // insert + Array(Row(42)), // select + Array.empty[Row] // insert ) verifySqlScriptResult(commands, expected) } From 3e3373ddb4b3272d1c7e107fad9fc6d69a5f44b9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 13 Aug 2024 12:30:29 +0200 Subject: [PATCH 94/99] Add the rest of the keywords --- docs/sql-ref-ansi-compliance.md | 4 ++++ .../spark/sql/catalyst/parser/SqlBaseParser.g4 | 8 ++++++++ .../sql-tests/results/ansi/keywords.sql.out | 12 ++++-------- .../resources/sql-tests/results/keywords.sql.out | 13 +++++-------- .../ThriftServerWithSparkContextSuite.scala | 2 +- 5 files changed, 22 insertions(+), 17 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index a6b9051547a3f..60e6632715a98 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -452,6 +452,7 @@ Below is a list of all the keywords in Spark SQL. |COMPENSATION|non-reserved|non-reserved|non-reserved| |COMPUTE|non-reserved|non-reserved|non-reserved| |CONCATENATE|non-reserved|non-reserved|non-reserved| +|CONDITION|non-reserved|non-reserved|non-reserved| |CONSTRAINT|reserved|non-reserved|reserved| |CONTAINS|non-reserved|non-reserved|non-reserved| |CONTINUE|non-reserved|non-reserved|non-reserved| @@ -506,6 +507,7 @@ Below is a list of all the keywords in Spark SQL. |EXCLUDE|non-reserved|non-reserved|non-reserved| |EXECUTE|reserved|non-reserved|reserved| |EXISTS|non-reserved|non-reserved|reserved| +|EXIT|non-reserved|non-reserved|non-reserved| |EXPLAIN|non-reserved|non-reserved|non-reserved| |EXPORT|non-reserved|non-reserved|non-reserved| |EXTENDED|non-reserved|non-reserved|non-reserved| @@ -523,6 +525,7 @@ Below is a list of all the keywords in Spark SQL. |FOREIGN|reserved|non-reserved|reserved| |FORMAT|non-reserved|non-reserved|non-reserved| |FORMATTED|non-reserved|non-reserved|non-reserved| +|FOUND|non-reserved|non-reserved|non-reserved| |FROM|reserved|non-reserved|reserved| |FULL|reserved|strict-non-reserved|reserved| |FUNCTION|non-reserved|non-reserved|reserved| @@ -686,6 +689,7 @@ Below is a list of all the keywords in Spark SQL. 
|SOURCE|non-reserved|non-reserved|non-reserved| |SPECIFIC|non-reserved|non-reserved|reserved| |SQL|reserved|non-reserved|reserved| +|SQLEXCEPTION|non-reserved|non-reserved|non-reserved| |START|non-reserved|non-reserved|reserved| |STATISTICS|non-reserved|non-reserved|non-reserved| |STORED|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f80a04b911c22..7622601aa3559 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1521,6 +1521,7 @@ ansiNonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONDITION | CONTAINS | CONTINUE | COST @@ -1575,11 +1576,13 @@ ansiNonReserved | FOLLOWING | FORMAT | FORMATTED + | FOUND | FUNCTION | FUNCTIONS | GENERATED | GLOBAL | GROUPING + | HANDLER | HOUR | HOURS | IDENTIFIER_KW @@ -1705,6 +1708,7 @@ ansiNonReserved | SORTED | SOURCE | SPECIFIC + | SQLEXCEPTION | START | STATISTICS | STORED @@ -1849,6 +1853,7 @@ nonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONDITION | CONSTRAINT | CONTAINS | CONTINUE @@ -1919,6 +1924,7 @@ nonReserved | FOREIGN | FORMAT | FORMATTED + | FOUND | FROM | FUNCTION | FUNCTIONS @@ -1927,6 +1933,7 @@ nonReserved | GRANT | GROUP | GROUPING + | HANDLER | HAVING | HOUR | HOURS @@ -2070,6 +2077,7 @@ nonReserved | SOURCE | SPECIFIC | SQL + | SQLEXCEPTION | START | STATISTICS | STORED diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index ccb427b4f3706..fb014a09c1162 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -58,7 +58,7 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false -CONDITION true +CONDITION false CONSTRAINT true CONTAINS false CONTINUE false @@ -131,7 +131,7 @@ FOR true FOREIGN true FORMAT false FORMATTED false -FOUND true +FOUND false FROM true FULL true FUNCTION false @@ -141,7 +141,7 @@ GLOBAL false GRANT true GROUP true GROUPING false -HANDLER true +HANDLER false HAVING true HOUR false HOURS false @@ -293,7 +293,7 @@ SORTED false SOURCE false SPECIFIC false SQL true -SQLEXCEPTION true +SQLEXCEPTION false START false STATISTICS false STORED false @@ -384,7 +384,6 @@ CHECK COLLATE COLLATION COLUMN -CONDITION CONSTRAINT CREATE CROSS @@ -403,12 +402,10 @@ FETCH FILTER FOR FOREIGN -FOUND FROM FULL GRANT GROUP -HANDLER HAVING IN INNER @@ -436,7 +433,6 @@ SELECT SESSION_USER SOME SQL -SQLEXCEPTION TABLE THEN TIME diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 89af24505765e..e9430cf2a187e 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -58,7 +58,7 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false -CONDITION true +CONDITION false CONSTRAINT false CONTAINS false CONTINUE false @@ -131,7 +131,7 @@ FOR false FOREIGN false FORMAT false FORMATTED false -FOUND true +FOUND false FROM false FULL false FUNCTION false @@ -141,7 +141,7 @@ GLOBAL false GRANT false GROUP false GROUPING false -HANDLER true +HANDLER false HAVING false HOUR false HOURS false @@ -293,7 +293,7 @@ SORTED false SOURCE false 
SPECIFIC false SQL false -SQLEXCEPTION true +SQLEXCEPTION false START false STATISTICS false STORED false @@ -372,7 +372,4 @@ SELECT keyword from SQL_KEYWORDS() WHERE reserved -- !query schema struct -- !query output -CONDITION -FOUND -HANDLER -SQLEXCEPTION + diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 3b5f23c22d831..9327cbf916c33 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,CONTINUE,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXIT,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUN
DED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONDITION,CONSTRAINT,CONTAINS,CONTINUE,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXIT,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FOUND,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HANDLER,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,SQLEXCEPTION,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From bf9f409e1850310e8bc0c96425798f92762466cd Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 13 Aug 2024 20:24:25 +0200 Subject: [PATCH 95/99] Make CONDITION reserved word --- docs/sql-ref-ansi-compliance.md | 2 +- .../apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 | 4 ---- .../apache/spark/sql/catalyst/parser/SqlBaseParser.g4 | 2 -- .../spark/sql/scripting/SqlScriptingInterpreter.scala | 2 +- .../resources/sql-tests/results/ansi/keywords.sql.out | 3 ++- .../test/resources/sql-tests/results/keywords.sql.out | 4 ++-- 
.../sql/scripting/SqlScriptingInterpreterSuite.scala | 10 +++------- 7 files changed, 9 insertions(+), 18 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 60e6632715a98..840279b3c64fa 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -452,7 +452,7 @@ Below is a list of all the keywords in Spark SQL. |COMPENSATION|non-reserved|non-reserved|non-reserved| |COMPUTE|non-reserved|non-reserved|non-reserved| |CONCATENATE|non-reserved|non-reserved|non-reserved| -|CONDITION|non-reserved|non-reserved|non-reserved| +|CONDITION|reserved|reserved|reserved| |CONSTRAINT|reserved|non-reserved|reserved| |CONTAINS|non-reserved|non-reserved|non-reserved| |CONTINUE|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index eb7e9c57ab75a..a717911f3843c 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -578,10 +578,6 @@ IDENTIFIER | UNICODE_LETTER+ '://' (UNICODE_LETTER | DIGIT | '_' | '/' | '-' | '.' | '?' | '=' | '&' | '#' | '%')+ ; -SQLSTATE - : '\'' (LETTER | DIGIT)(LETTER | DIGIT)(LETTER | DIGIT)(LETTER | DIGIT)(LETTER | DIGIT) '\'' - ; - BACKQUOTED_IDENTIFIER : '`' ( ~'`' | '``' )* '`' ; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 7622601aa3559..3f90316f6163f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1521,7 +1521,6 @@ ansiNonReserved | COMPENSATION | COMPUTE | CONCATENATE - | CONDITION | CONTAINS | CONTINUE | COST @@ -1853,7 +1852,6 @@ nonReserved | COMPENSATION | COMPUTE | CONCATENATE - | CONDITION | CONSTRAINT | CONTAINS | CONTINUE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 4c00e30f6f61f..9e5f80fb784db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -59,7 +59,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { } /** - * Fetch the name of the Create Variable plan. + * Transform [[CompoundBody]] into [[CompoundBodyExec]]. * @param compoundBody * CompoundBody to be transformed into CompoundBodyExec. 
* @param isExitHandler diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index fb014a09c1162..c5fb9750fd23b 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -58,7 +58,7 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false -CONDITION false +CONDITION true CONSTRAINT true CONTAINS false CONTINUE false @@ -384,6 +384,7 @@ CHECK COLLATE COLLATION COLUMN +CONDITION CONSTRAINT CREATE CROSS diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index e9430cf2a187e..f669298420272 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -58,7 +58,7 @@ COMPACTIONS false COMPENSATION false COMPUTE false CONCATENATE false -CONDITION false +CONDITION true CONSTRAINT false CONTAINS false CONTINUE false @@ -372,4 +372,4 @@ SELECT keyword from SQL_KEYWORDS() WHERE reserved -- !query schema struct -- !query output - +CONDITION diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index e2a5e90a4ed89..7d87889e9da3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -702,8 +702,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array(Row(1)), // select i Array.empty[Row], // set i Array(Row(2)), // select i - Array.empty[Row], // set i - Array.empty[Row] // drop var + Array.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -721,8 +720,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession |""".stripMargin val expected = Seq( - Array.empty[Row], // declare i - Array.empty[Row] // drop i + Array.empty[Row] // declare i ) verifySqlScriptResult(commands, expected) } @@ -758,9 +756,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession Array.empty[Row], // increase j Array(Row(1, 1)), // select i, j Array.empty[Row], // increase j - Array.empty[Row], // increase i - Array.empty[Row], // drop j - Array.empty[Row] // drop i + Array.empty[Row] // increase i ) verifySqlScriptResult(commands, expected) } From ff0f330fc84f0955bb2dc4d63e7708f6bec27717 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 13 Aug 2024 20:41:27 +0200 Subject: [PATCH 96/99] Add variable cleanup when executing leave statement --- .../scripting/SqlScriptingExecutionNode.scala | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 8348ff51cc0e7..b35de01511428 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.scripting import scala.collection.mutable - import org.apache.spark.{SparkException, SparkThrowable} import 
org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.parser.SingleStatement import org.apache.spark.sql.{Dataset, Row, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -224,6 +224,18 @@ class CompoundBodyExec( statement } + /** + * Drop variables declared in this CompoundBody. + */ + private def cleanup(): Unit = { + // Filter out internal DropVariable statements and execute them. + statements.filter( + dropVar => dropVar.isInstanceOf[SingleStatementExec] + && dropVar.asInstanceOf[SingleStatementExec].parsedPlan.isInstanceOf[DropVariable] + && dropVar.isInternal) + .foreach(_.asInstanceOf[SingleStatementExec].execute(session)) + } + /** * Check if the leave statement was used, if it is not used stop iterating surrounding * [[CompoundBodyExec]] and move iterator forward. If the label of the block matches the label of @@ -233,8 +245,10 @@ class CompoundBodyExec( */ private def handleLeave(leave: LeaveStatementExec): LeaveStatementExec = { if (!leave.used) { - // Hard stop the iteration of the current begin/end block + // Hard stop the iteration of the current begin/end block. stopIteration = true + // Cleanup variables declared in the current block. + cleanup() // If label of the block matches the label of the leave statement, // mark the leave statement as used. label can be None in case of a // CompoundBody inside loop or if/else structure. In such cases, @@ -300,7 +314,7 @@ class CompoundBodyExec( handleLeave(leave) case leafStatement: LeafStatementExec => // This check is done to handle error in surrounding begin/end block - // if it was not handled in the nested block + // if it was not handled in the nested block. 
handleError(leafStatement) case nonLeafStatement: NonLeafStatementExec => nonLeafStatement } From a81a64ec53050fec4449e07c80a730b323f3681a Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 13 Aug 2024 20:46:52 +0200 Subject: [PATCH 97/99] Fix imports --- .../apache/spark/sql/scripting/SqlScriptingExecutionNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index b35de01511428..cd315dffecdb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.scripting import scala.collection.mutable + import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.parser.SingleStatement import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} From 48c5929dea152297fb7a4fe4fa52950faa23b143 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 14 Aug 2024 13:38:37 +0200 Subject: [PATCH 98/99] Update doc --- docs/sql-ref-ansi-compliance.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 840279b3c64fa..ff9de533cf0aa 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -452,7 +452,7 @@ Below is a list of all the keywords in Spark SQL. |COMPENSATION|non-reserved|non-reserved|non-reserved| |COMPUTE|non-reserved|non-reserved|non-reserved| |CONCATENATE|non-reserved|non-reserved|non-reserved| -|CONDITION|reserved|reserved|reserved| +|CONDITION|reserved|non-reserved|reserved| |CONSTRAINT|reserved|non-reserved|reserved| |CONTAINS|non-reserved|non-reserved|non-reserved| |CONTINUE|non-reserved|non-reserved|non-reserved| From 97c0f5c9f4d1219bd990e37dcab34e6d94eba37d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 15 Aug 2024 13:48:52 +0200 Subject: [PATCH 99/99] Fix tests --- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 3f90316f6163f..f851782021627 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1852,6 +1852,7 @@ nonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONDITION | CONSTRAINT | CONTAINS | CONTINUE
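
As a usage-level summary of the behavior these last patches converge on, the sketch below shows a SQL script exercising the pieces touched above: a named condition, an EXIT handler, and handler resolution by SQLSTATE. It is illustrative only; the condition name, SQLSTATE value, and statements are assumptions, and the DECLARE CONDITION keyword order is inferred from visitDeclareCondition (a multipart identifier plus an optional string literal defaulting to '45000') rather than taken from a grammar rule shown in these patches.

  BEGIN
    DECLARE div_by_zero CONDITION FOR '22012';

    -- EXIT handler: transformBodyIntoExec appends a synthetic LEAVE to the
    -- handler body, so after the handler runs, execution exits the surrounding
    -- compound and the exited block drops its declared variables
    -- (cleanup() from [PATCH 96/99]).
    DECLARE EXIT HANDLER FOR div_by_zero
    BEGIN
      SELECT 'handled';
    END;

    -- declareHandler also accepts a single statement or a SET statement as the
    -- handler body, for example (illustrative):
    --   DECLARE CONTINUE HANDLER FOR SQLEXCEPTION SET failures = failures + 1;

    SELECT 1 / 0;        -- assumed to raise SQLSTATE '22012' under ANSI mode
    SELECT 'unreached';  -- skipped once the EXIT handler leaves the block
  END

Handler lookup follows getHandler: an exact match on the raised SQLSTATE first (named conditions are resolved to their SQLSTATE when conditionHandlerMap is built), then a NOT FOUND handler for states in class '02', then a SQLEXCEPTION handler.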