From 37403b0f500d8de83a73efb7fcf617f77b5f461e Mon Sep 17 00:00:00 2001 From: Costas Zarifis Date: Wed, 21 Aug 2024 09:04:33 +0900 Subject: [PATCH] [SPARK-49269][SQL] Eagerly evaluate VALUES() list in AstBuilder ### What changes were proposed in this pull request? This is a continuation of a prior performance improvement (https://github.com/apache/spark/pull/47428) that eagerly evaluates memory-heavy `UnresolvedInlineTable` parse tree nodes as soon as they are constructed in the AstBuilder. This PR applies the optimization to any statement that can contain one or more `VALUES()` clauses (such as subqueries), instead of only to `INSERT INTO ... VALUES` statements, as the prior PR did (a minimal before/after sketch is appended after the diff). ### Why are the changes needed? These changes reduce the memory footprint of every statement that can contain a `VALUES()` clause. They also improve upon the previous optimization by avoiding unnecessary traversals of the parse tree, which both improves runtime performance and minimizes the amount of time the `UnresolvedInlineTable` is kept in memory. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added Scala tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #47791 from costas-db/eagerlyEvaluateUnresolvedInlineTableInAstBuilder. Authored-by: Costas Zarifis Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/parser/AstBuilder.scala | 14 +- .../util/EvaluateUnresolvedInlineTable.scala | 13 +- .../apache/spark/sql/internal/SQLConf.scala | 6 +- .../sql/catalyst/parser/DDLParserSuite.scala | 2 +- .../sql/catalyst/parser/PlanParserSuite.scala | 52 +++-- .../analyzer-results/inline-table.sql.out | 4 +- .../postgreSQL/window_part4.sql.out | 2 +- .../udf/udf-inline-table.sql.out | 4 +- .../sql-tests/results/inline-table.sql.out | 4 +- .../results/postgreSQL/window_part4.sql.out | 2 +- .../results/udf/udf-inline-table.sql.out | 4 +- .../InlineTableParsingImprovementsSuite.scala | 217 +++++++++++++++--- .../command/DeclareVariableParserSuite.scala | 11 +- 13 files changed, 251 insertions(+), 84 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0bb4fc9c90d8a..038f15ee11035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -468,7 +468,7 @@ class AstBuilder extends DataTypeAstBuilder val (relationCtx, options, cols, partition, ifPartitionNotExists, byName) = visitInsertIntoTable(table) withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { - val insertIntoStatement = InsertIntoStatement( + InsertIntoStatement( createUnresolvedRelation(relationCtx, ident, options), partition, cols, @@ -476,11 +476,6 @@ class AstBuilder extends DataTypeAstBuilder overwrite = false, ifPartitionNotExists, byName) - if (conf.getConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER)) { - EvaluateUnresolvedInlineTable.evaluate(insertIntoStatement) - } else { - insertIntoStatement - } }) case table: InsertOverwriteTableContext => val (relationCtx, options, cols, partition, ifPartitionNotExists, byName) @@ -1897,7 +1892,12 @@ class AstBuilder extends DataTypeAstBuilder Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - val table = UnresolvedInlineTable(aliases, rows.toSeq) + val unresolvedTable = 
UnresolvedInlineTable(aliases, rows.toSeq) + val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) + } else { + unresolvedTable + } table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala index a55f70c238a8a..51cab6bff3b03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala @@ -35,17 +35,8 @@ import org.apache.spark.sql.types.{StructField, StructType} object EvaluateUnresolvedInlineTable extends SQLConfHelper with AliasHelper with EvalHelper with CastSupport { - def evaluate(plan: LogicalPlan): LogicalPlan = { - traversePlanAndEvalUnresolvedInlineTable(plan) - } - - def traversePlanAndEvalUnresolvedInlineTable(plan: LogicalPlan): LogicalPlan = { - plan match { - case table: UnresolvedInlineTable if table.expressionsResolved => - evaluateUnresolvedInlineTable(table) - case _ => plan.mapChildren(traversePlanAndEvalUnresolvedInlineTable) - } - } + def evaluate(plan: UnresolvedInlineTable): LogicalPlan = + if (plan.expressionsResolved) evaluateUnresolvedInlineTable(plan) else plan def evaluateUnresolvedInlineTable(table: UnresolvedInlineTable): LogicalPlan = { validateInputDimension(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 935111387745e..096cc974fbe6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -969,11 +969,11 @@ object SQLConf { .booleanConf .createWithDefault(true) - val OPTIMIZE_INSERT_INTO_VALUES_PARSER = - buildConf("spark.sql.parser.optimizeInsertIntoValuesParser") + val EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED = + buildConf("spark.sql.parser.eagerEvalOfUnresolvedInlineTable") .internal() .doc("Controls whether we optimize the ASTree that gets generated when parsing " + - "`insert into ... 
values` DML statements.") + "VALUES lists (UnresolvedInlineTable) by eagerly evaluating it in the AST Builder.") .booleanConf .createWithDefault(true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 59602a4c77d08..c930292f2793c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2633,7 +2633,7 @@ class DDLParserSuite extends AnalysisTest { for (optimizeInsertIntoValues <- Seq(true, false)) { withSQLConf( - SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key -> + SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> optimizeInsertIntoValues.toString) { comparePlans(parsePlan(dateTypeSql), insertPartitionPlan( "2019-01-02", optimizeInsertIntoValues)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 8d01040563361..e0217a5637a81 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import scala.annotation.nowarn import org.apache.spark.SparkThrowable -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{EvaluateUnresolvedInlineTable, FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedParameter, PosParameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -1000,14 +1000,28 @@ class PlanParserSuite extends AnalysisTest { } test("inline table") { - assertEqual("values 1, 2, 3, 4", - UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) + for (optimizeValues <- Seq(true, false)) { + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + optimizeValues.toString) { + val unresolvedTable1 = + UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x)))) + val table1 = if (optimizeValues) { + EvaluateUnresolvedInlineTable.evaluate(unresolvedTable1) + } else { + unresolvedTable1 + } + assertEqual("values 1, 2, 3, 4", table1) - assertEqual( - "values (1, 'a'), (2, 'b') as tbl(a, b)", - UnresolvedInlineTable( - Seq("a", "b"), - Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil).as("tbl")) + val unresolvedTable2 = UnresolvedInlineTable( + Seq("a", "b"), Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil) + val table2 = if (optimizeValues) { + EvaluateUnresolvedInlineTable.evaluate(unresolvedTable2) + } else { + unresolvedTable2 + } + assertEqual("values (1, 'a'), (2, 'b') as tbl(a, b)", table2.as("tbl")) + } + } } test("simple select query with !> and !<") { @@ -1907,12 +1921,22 @@ class PlanParserSuite extends AnalysisTest { } test("SPARK-42553: NonReserved keyword 'interval' can be column name") { - comparePlans( - parsePlan("SELECT interval FROM VALUES ('abc') AS 
tbl(interval);"), - UnresolvedInlineTable( - Seq("interval"), - Seq(Literal("abc")) :: Nil).as("tbl").select($"interval") - ) + for (optimizeValues <- Seq(true, false)) { + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + optimizeValues.toString) { + val unresolvedTable = + UnresolvedInlineTable(Seq("interval"), Seq(Literal("abc")) :: Nil) + val table = if (optimizeValues) { + EvaluateUnresolvedInlineTable.evaluate(unresolvedTable) + } else { + unresolvedTable + } + comparePlans( + parsePlan("SELECT interval FROM VALUES ('abc') AS tbl(interval);"), + table.as("tbl").select($"interval") + ) + } + } } test("SPARK-44066: parsing of positional parameters") { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out index 988df7de1a3cf..78539effe188e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out @@ -115,7 +115,7 @@ org.apache.spark.sql.AnalysisException -- !query select * from values ("one", 2.0), ("two") as data(a, b) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", @@ -157,7 +157,7 @@ org.apache.spark.sql.AnalysisException -- !query select * from values ("one"), ("two") as data(a, b) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out index 2333cce874d31..f042116182f7d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out @@ -498,7 +498,7 @@ SELECT a, b, SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION", "sqlState" : "42000", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out index fb6130be5b6b4..786b5ac49b126 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out @@ -101,7 +101,7 @@ org.apache.spark.sql.AnalysisException -- !query select udf(a), udf(b) from values ("one", 2.0), ("two") as data(a, b) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", @@ -143,7 +143,7 @@ org.apache.spark.sql.AnalysisException -- !query select udf(a), udf(b) from values ("one"), ("two") as data(a, b) -- !query analysis -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : 
"42000", diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 4dcdf8ac3e980..0a2c7b0f55ed2 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -131,7 +131,7 @@ select * from values ("one", 2.0), ("two") as data(a, b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", @@ -177,7 +177,7 @@ select * from values ("one"), ("two") as data(a, b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out index 2085186dc8cfa..2d539725b2a70 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out @@ -497,7 +497,7 @@ FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION", "sqlState" : "42000", diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out index d09f56a836788..3e84ec09c2150 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out @@ -115,7 +115,7 @@ select udf(a), udf(b) from values ("one", 2.0), ("two") as data(a, b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", @@ -161,7 +161,7 @@ select udf(a), udf(b) from values ("one"), ("two") as data(a, b) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH", "sqlState" : "42000", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala index f305670dded8d..8c776874eaa1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala @@ -19,11 +19,19 @@ package org.apache.spark.sql import java.util.UUID +import org.apache.spark.sql.catalyst.analysis.UnresolvedInlineTable +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSession { + /** + * SQL parser. 
+ */ + private lazy val parser = spark.sessionState.sqlParser + /** * Generate a random table name. */ @@ -59,7 +67,14 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess */ private def generateInsertStatementWithLiterals(tableName: String, numRows: Int): String = { val baseQuery = s"INSERT INTO $tableName (id, first_name, last_name, age, gender," + - s" email, phone_number, address, city, state, zip_code, country, registration_date) VALUES " + s" email, phone_number, address, city, state, zip_code, country, registration_date) " + baseQuery + generateValuesWithLiterals(numRows) + ";" + } + + /** + * Generate a VALUES clause with the given number of rows using basic literals. + */ + private def generateValuesWithLiterals(numRows: Int = 10): String = { val rows = (1 to numRows).map { i => val id = i val firstName = s"'FirstName_$id'" @@ -79,7 +94,33 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess s" $address, $city, $state, $zipCode, $country, $registrationDate)" }.mkString(",\n") - baseQuery + rows + ";" + s" VALUES $rows" + } + + /** + * Traverse the plan and check for the presence of the given node type. + */ + private def traversePlanAndCheckForNodeType[T <: LogicalPlan]( + plan: LogicalPlan, nodeType: Class[T]): Boolean = plan match { + case node if nodeType.isInstance(node) => true + case n: Project => + // If the plan node is a Project, we need to check the expressions in the project list + // and the child nodes. + n.projectList.exists(traverseExpressionAndCheckForNodeType(_, nodeType)) || + n.children.exists(traversePlanAndCheckForNodeType(_, nodeType)) + case node if node.children.isEmpty => false + case _ => plan.children.exists(traversePlanAndCheckForNodeType(_, nodeType)) + } + + /** + * Traverse the expression and check for the presence of the given node type. 
+ */ + private def traverseExpressionAndCheckForNodeType[T <: LogicalPlan]( + expression: Expression, nodeType: Class[T]): Boolean = expression match { + case scalarSubquery: ScalarSubquery => scalarSubquery.plan.exists( + traversePlanAndCheckForNodeType(_, nodeType)) + case _ => + expression.children.exists(traverseExpressionAndCheckForNodeType(_, nodeType)) } /** @@ -87,54 +128,64 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess */ private def generateInsertStatementsWithComplexExpressions( tableName: String): String = { - s""" - INSERT INTO $tableName (id, first_name, last_name, age, gender, - email, phone_number, address, city, state, zip_code, country, registration_date) VALUES - - (1, base64('FirstName_1'), base64('LastName_1'), 10+10, 'M', 'usr' || '@gmail.com', - concat('555','-1234'), hex('123 Fake St'), 'Anytown', 'CA', '12345', 'USA', - '2021-01-01'), + s""" + INSERT INTO $tableName (id, first_name, last_name, age, gender, + email, phone_number, address, city, state, zip_code, country, registration_date) VALUES + (1, base64('FirstName_1'), base64('LastName_1'), 10+10, 'M', 'usr' || '@gmail.com', + concat('555','-1234'), hex('123 Fake St'), 'Anytown', 'CA', '12345', 'USA', + '2021-01-01'), - (2, 'FirstName_2', string(5), abs(-8), 'F', 'usr@gmail.com', '555-1234', '123 Fake St', - concat('Anytown', 'sada'), 'CA', '12345', 'USA', '2021-01-01'), + (2, 'FirstName_2', string(5), abs(-8), 'F', 'usr@gmail.com', '555-1234', '123 Fake St', + concat('Anytown', 'sada'), 'CA', '12345', 'USA', '2021-01-01'), - (3, 'FirstName_3', 'LastName_3', 34::int, 'M', 'usr@gmail.com', '555-1234', - '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'), - - (4, left('FirstName_4', 5), upper('LastName_4'), acos(1), 'F', 'user@gmail.com', - '555-1234', '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'); - """ - } + (3, 'FirstName_3', 'LastName_3', 34::int, 'M', 'usr@gmail.com', '555-1234', + '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'), + (4, left('FirstName_4', 5), upper('LastName_4'), acos(1), 'F', 'user@gmail.com', + '555-1234', '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'); + """ + } test("Insert Into Values optimization - Basic literals.") { - // Set the number of inserted rows to 10000. - val rowCount = 10000 + // Set the number of inserted rows to 10. + val rowCount = 10 var firstTableName: Option[String] = None - Seq(true, false).foreach { insertIntoValueImprovementEnabled => + Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled => // Create a table with a randomly generated name. val tableName = createTable // Set the feature flag for the InsertIntoValues improvement. - withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key -> - insertIntoValueImprovementEnabled.toString) { + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + eagerEvalOfUnresolvedInlineTableEnabled.toString) { // Generate an INSERT INTO VALUES statement. val sqlStatement = generateInsertStatementWithLiterals(tableName, rowCount) + + // Parse the SQL statement. + val plan = parser.parsePlan(sqlStatement) + + // Traverse the plan and check for the presence of appropriate nodes depending on the + // feature flag. + if (eagerEvalOfUnresolvedInlineTableEnabled) { + assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation])) + } else { + assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable])) + } + spark.sql(sqlStatement) - // Double check that the insertion was successful. 
- val countStar = spark.sql(s"SELECT count(*) FROM $tableName").collect() - assert(countStar.head.getLong(0) == rowCount, - "The number of rows in the table should match the number of rows inserted.") + // Double check that the insertion was successful. + val countStar = spark.sql(s"SELECT count(*) FROM $tableName").collect() + assert(countStar.head.getLong(0) == rowCount, + "The number of rows in the table should match the number of rows inserted.") // Check that both insertions will produce equivalent tables. if (firstTableName.isEmpty) { firstTableName = Some(tableName) } else { - val df1 = spark.table(firstTableName.get) - val df2 = spark.table(tableName) - checkAnswer(df1, df2) + val df1 = spark.table(firstTableName.get) + val df2 = spark.table(tableName) + checkAnswer(df1, df2) } } } @@ -142,16 +193,27 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess test("Insert Into Values optimization - Basic literals & expressions.") { var firstTableName: Option[String] = None - Seq(true, false).foreach { insertIntoValueImprovementEnabled => + Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled => // Create a table with a randomly generated name. val tableName = createTable // Set the feature flag for the InsertIntoValues improvement. - withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key -> - insertIntoValueImprovementEnabled.toString) { + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + eagerEvalOfUnresolvedInlineTableEnabled.toString) { // Generate an INSERT INTO VALUES statement. val sqlStatement = generateInsertStatementsWithComplexExpressions(tableName) + + // Parse the SQL statement. + val plan = parser.parsePlan(sqlStatement) + + // Traverse the plan and check for the presence of appropriate nodes. + // In this case, the plan should always contain a UnresolvedInlineTable node + // because the expressions are not eagerly resolved, therefore + // `plan.expressionsResolved` in `EvaluateUnresolvedInlineTable.evaluate` will + // always be false. + assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable])) + spark.sql(sqlStatement) // Check that both insertions will produce equivalent tables. @@ -168,17 +230,30 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess test("Insert Into Values with defaults.") { var firstTableName: Option[String] = None - Seq(true, false).foreach { insertIntoValueImprovementEnabled => + Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled => // Create a table with default values specified. val tableName = createTable // Set the feature flag for the InsertIntoValues improvement. - withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key -> - insertIntoValueImprovementEnabled.toString) { + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + eagerEvalOfUnresolvedInlineTableEnabled.toString) { // Generate an INSERT INTO VALUES statement that omits all columns // containing a DEFAULT value. - spark.sql(s"INSERT INTO $tableName (id) VALUES (1);") + val sqlStatement = s"INSERT INTO $tableName (id) VALUES (1);" + + // Parse the SQL statement. + val plan = parser.parsePlan(sqlStatement) + + // Traverse the plan and check for the presence of appropriate nodes depending on the + // feature flag. 
+ if (eagerEvalOfUnresolvedInlineTableEnabled) { + assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation])) + } else { + assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable])) + } + + spark.sql(sqlStatement) // Verify that the default values are applied correctly. val resultRow = spark.sql( @@ -226,4 +301,72 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess } } } + + test("SPARK-49269: Value list in subquery") { + var firstDF: Option[DataFrame] = None + val flagVals = Seq(true, false) + flagVals.foreach { eagerEvalOfUnresolvedInlineTableEnabled => + // Set the feature flag for the InsertIntoValues improvement. + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + eagerEvalOfUnresolvedInlineTableEnabled.toString) { + + // Generate a subquery with a VALUES clause. + val sqlStatement = s"SELECT * FROM (${generateValuesWithLiterals()});" + + // Parse the SQL statement. + val plan = parser.parsePlan(sqlStatement) + + // Traverse the plan and check for the presence of appropriate nodes depending on the + // feature flag. + if (eagerEvalOfUnresolvedInlineTableEnabled) { + assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation])) + } else { + assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable])) + } + + val res = spark.sql(sqlStatement) + + // Check that both insertions will produce equivalent tables. + if (flagVals.head == eagerEvalOfUnresolvedInlineTableEnabled) { + firstDF = Some(res) + } else { + checkAnswer(res, firstDF.get) + } + } + } + } + + test("SPARK-49269: Value list in projection list subquery") { + var firstDF: Option[DataFrame] = None + val flagVals = Seq(true, false) + flagVals.foreach { eagerEvalOfUnresolvedInlineTableEnabled => + // Set the feature flag for the InsertIntoValues improvement. + withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key -> + eagerEvalOfUnresolvedInlineTableEnabled.toString) { + + // Generate a subquery with a VALUES clause in the projection list. + val sqlStatement = s"SELECT (SELECT COUNT(*) FROM ${generateValuesWithLiterals()});" + + // Parse the SQL statement. + val plan = parser.parsePlan(sqlStatement) + + // Traverse the plan and check for the presence of appropriate nodes depending on the + // feature flag. + if (eagerEvalOfUnresolvedInlineTableEnabled) { + assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation])) + } else { + assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable])) + } + + val res = spark.sql(sqlStatement) + + // Check that both insertions will produce equivalent tables. 
+ if (flagVals.head == eagerEvalOfUnresolvedInlineTableEnabled) { + firstDF = Some(res) + } else { + checkAnswer(res, firstDF.get) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala index a292afe6a7c28..bc42937b93a92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.command +import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedIdentifier, UnresolvedInlineTable} import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Divide, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, Project, SubqueryAlias} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType, IntegerType, MapType, NullType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -91,6 +93,13 @@ class DeclareVariableParserSuite extends AnalysisTest with SharedSparkSession { Cast(UnresolvedFunction("CURRENT_DATABASE", Nil, isDistinct = false), StringType), "CURRENT_DATABASE()"), replace = false)) + val subqueryAliasChild = + if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) { + EvaluateUnresolvedInlineTable.evaluate( + UnresolvedInlineTable(Seq("c1"), Seq(Literal(1)) :: Nil)) + } else { + UnresolvedInlineTable(Seq("c1"), Seq(Literal(1)) :: Nil) + } comparePlans( parsePlan("DECLARE VARIABLE var1 INT DEFAULT (SELECT c1 FROM VALUES(1) AS T(c1))"), CreateVariable( @@ -99,7 +108,7 @@ class DeclareVariableParserSuite extends AnalysisTest with SharedSparkSession { Cast(ScalarSubquery( Project(UnresolvedAttribute("c1") :: Nil, SubqueryAlias(Seq("T"), - UnresolvedInlineTable(Seq("c1"), Seq(Literal(1)) :: Nil)))), IntegerType), + subqueryAliasChild))), IntegerType), "(SELECT c1 FROM VALUES(1) AS T(c1))"), replace = false)) }
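
For reviewers, a minimal before/after sketch of the parse-time behavior this patch changes, mirroring the assertions in `InlineTableParsingImprovementsSuite`. This is an illustrative standalone snippet, not part of the change: the object name, session setup, and example query are hypothetical, while the conf key and the plan node classes are the ones touched by this patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedInlineTable
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.internal.SQLConf

// Hypothetical driver illustrating the flag's effect on the parsed plan.
object EagerInlineTableSketch extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  val parser = spark.sessionState.sqlParser
  val query = "SELECT * FROM VALUES (1, 'a'), (2, 'b') AS tbl(x, y)"

  // Flag on (the default): a literal-only VALUES list is folded into a
  // LocalRelation while still inside the AstBuilder, so no
  // UnresolvedInlineTable node survives parsing -- even when VALUES appears
  // in a subquery rather than an INSERT INTO statement.
  spark.conf.set(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key, "true")
  assert(parser.parsePlan(query)
    .collectFirst { case r: LocalRelation => r }.isDefined)

  // Flag off: the parse tree keeps the UnresolvedInlineTable node and the
  // analyzer evaluates it later, as before this patch.
  spark.conf.set(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key, "false")
  assert(parser.parsePlan(query)
    .collectFirst { case t: UnresolvedInlineTable => t }.isDefined)

  spark.stop()
}
```

Note that rows containing non-literal expressions are not folded even with the flag on, since `expressionsResolved` is false for them (see the "Basic literals & expressions" test above).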