Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-49269][SQL] Eagerly evaluate VALUES() list in AstBuilder #47791

Original file line number Diff line number Diff line change
Expand Up @@ -454,19 +454,14 @@ class AstBuilder extends DataTypeAstBuilder
val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
= visitInsertIntoTable(table)
withIdentClause(relationCtx, ident => {
val insertIntoStatement = InsertIntoStatement(
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident, options),
partition,
cols,
query,
overwrite = false,
ifPartitionNotExists,
byName)
if (conf.getConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER)) {
EvaluateUnresolvedInlineTable.evaluate(insertIntoStatement)
} else {
insertIntoStatement
}
})
case table: InsertOverwriteTableContext =>
val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
Expand Down Expand Up @@ -1882,7 +1877,12 @@ class AstBuilder extends DataTypeAstBuilder
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}

val table = UnresolvedInlineTable(aliases, rows.toSeq)
val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq)
val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {
EvaluateUnresolvedInlineTable.evaluate(unresolvedTable)
} else {
unresolvedTable
}
table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,8 @@ import org.apache.spark.sql.types.{StructField, StructType}
object EvaluateUnresolvedInlineTable extends SQLConfHelper
with AliasHelper with EvalHelper with CastSupport {

/**
 * Entry point for an arbitrary plan: hands `plan` to
 * [[traversePlanAndEvalUnresolvedInlineTable]] and returns whatever that rewrite produces.
 */
def evaluate(plan: LogicalPlan): LogicalPlan =
  traversePlanAndEvalUnresolvedInlineTable(plan)

/**
 * Rewrites `plan` from the top down. An [[UnresolvedInlineTable]] whose expressions are already
 * resolved is replaced by the result of [[evaluateUnresolvedInlineTable]]; any other node is
 * kept and the same rewrite is applied to each of its children.
 */
def traversePlanAndEvalUnresolvedInlineTable(plan: LogicalPlan): LogicalPlan = plan match {
  case inlineTable: UnresolvedInlineTable if inlineTable.expressionsResolved =>
    evaluateUnresolvedInlineTable(inlineTable)
  case other =>
    other.mapChildren(traversePlanAndEvalUnresolvedInlineTable)
}
/**
 * Evaluates `plan` through [[evaluateUnresolvedInlineTable]] when all of its expressions are
 * resolved; otherwise the table is handed back untouched.
 */
def evaluate(plan: UnresolvedInlineTable): LogicalPlan = plan match {
  case table if table.expressionsResolved => evaluateUnresolvedInlineTable(table)
  case unresolved => unresolved
}

def evaluateUnresolvedInlineTable(table: UnresolvedInlineTable): LogicalPlan = {
validateInputDimension(table)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,11 +969,11 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val OPTIMIZE_INSERT_INTO_VALUES_PARSER =
buildConf("spark.sql.parser.optimizeInsertIntoValuesParser")
val EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED =
buildConf("spark.sql.parser.eagerEvalOfUnresolvedInlineTable")
.internal()
.doc("Controls whether we optimize the ASTree that gets generated when parsing " +
"`insert into ... values` DML statements.")
"VALUES lists (UnresolvedInlineTable) by eagerly evaluating it in the AST Builder.")
.booleanConf
.createWithDefault(true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2633,7 +2633,7 @@ class DDLParserSuite extends AnalysisTest {

for (optimizeInsertIntoValues <- Seq(true, false)) {
withSQLConf(
SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
optimizeInsertIntoValues.toString) {
comparePlans(parsePlan(dateTypeSql), insertPartitionPlan(
"2019-01-02", optimizeInsertIntoValues))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ package org.apache.spark.sql

import java.util.UUID

import org.apache.spark.sql.catalyst.analysis.UnresolvedInlineTable
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSession {

/**
* SQL parser.
*/
private lazy val parser = spark.sessionState.sqlParser

/**
* Generate a random table name.
*/
Expand Down Expand Up @@ -82,76 +89,117 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
baseQuery + rows + ";"
}

/**
 * Check whether `plan` contains a node of the given type.
 *
 * Besides each node's `children`, the search also follows `subqueries`, i.e. the plans that are
 * embedded in the node's expressions. A VALUES list used inside a scalar subquery in the
 * projection list hangs off an expression rather than off `children`, so a children-only walk
 * would never reach it.
 *
 * @param plan     root of the (possibly unresolved) plan to search
 * @param nodeType runtime class of the node to look for
 * @return true if `plan` itself, or any node reachable through children or subquery
 *         expressions, is an instance of `nodeType`
 */
private def traversePlanAndCheckForNodeType[T <: LogicalPlan](
    plan: LogicalPlan, nodeType: Class[T]): Boolean = {
  // `exists` on an empty sequence is false, so leaf nodes need no special case.
  nodeType.isInstance(plan) ||
    (plan.children ++ plan.subqueries).exists(traversePlanAndCheckForNodeType(_, nodeType))
}

/**
* Generate an INSERT INTO VALUES statement with both literals and expressions.
*/
private def generateInsertStatementsWithComplexExpressions(
tableName: String): String = {
s"""
INSERT INTO $tableName (id, first_name, last_name, age, gender,
email, phone_number, address, city, state, zip_code, country, registration_date) VALUES

(1, base64('FirstName_1'), base64('LastName_1'), 10+10, 'M', 'usr' || '@gmail.com',
concat('555','-1234'), hex('123 Fake St'), 'Anytown', 'CA', '12345', 'USA',
'2021-01-01'),
s"""
INSERT INTO $tableName (id, first_name, last_name, age, gender,
email, phone_number, address, city, state, zip_code, country, registration_date)
${generateValuesWithComplexExpressions}
"""
}

(2, 'FirstName_2', string(5), abs(-8), 'F', 'usr@gmail.com', '555-1234', '123 Fake St',
concat('Anytown', 'sada'), 'CA', '12345', 'USA', '2021-01-01'),
/**
 * Generate a VALUES clause with complex expressions.
 *
 * The clause deliberately carries no trailing `;`: it is spliced into larger statements
 * (appended to an INSERT, or wrapped in a parenthesized subquery), and a statement terminator
 * in the middle of a statement would not parse. A statement built from this clause is still
 * complete without the terminator.
 *
 * @return the VALUES clause with four rows of thirteen columns each
 */
private def generateValuesWithComplexExpressions: String = {
  s""" VALUES
    (1, base64('FirstName_1'), base64('LastName_1'), 10+10, 'M', 'usr' || '@gmail.com',
    concat('555','-1234'), hex('123 Fake St'), 'Anytown', 'CA', '12345', 'USA',
    '2021-01-01'),

    (2, 'FirstName_2', string(5), abs(-8), 'F', 'usr@gmail.com', '555-1234', '123 Fake St',
    concat('Anytown', 'sada'), 'CA', '12345', 'USA', '2021-01-01'),

    (3, 'FirstName_3', 'LastName_3', 34::int, 'M', 'usr@gmail.com', '555-1234',
    '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'),

    (4, left('FirstName_4', 5), upper('LastName_4'), acos(1), 'F', 'user@gmail.com',
    '555-1234', '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01')
  """
}
test("Insert Into Values optimization - Basic literals.") {
// Set the number of inserted rows to 10000.
val rowCount = 10000
var firstTableName: Option[String] = None
Seq(true, false).foreach { insertIntoValueImprovementEnabled =>
Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled =>

// Create a table with a randomly generated name.
val tableName = createTable

// Set the feature flag for the InsertIntoValues improvement.
withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
insertIntoValueImprovementEnabled.toString) {
withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
eagerEvalOfUnresolvedInlineTableEnabled.toString) {

// Generate an INSERT INTO VALUES statement.
val sqlStatement = generateInsertStatementWithLiterals(tableName, rowCount)

// Parse the SQL statement.
val plan = parser.parsePlan(sqlStatement)

// Traverse the plan and check for the presence of appropriate nodes depending on the
// feature flag.
if (eagerEvalOfUnresolvedInlineTableEnabled) {
assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
} else {
assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
}

spark.sql(sqlStatement)

// Double check that the insertion was successful.
val countStar = spark.sql(s"SELECT count(*) FROM $tableName").collect()
assert(countStar.head.getLong(0) == rowCount,
"The number of rows in the table should match the number of rows inserted.")
// Double check that the insertion was successful.
val countStar = spark.sql(s"SELECT count(*) FROM $tableName").collect()
assert(countStar.head.getLong(0) == rowCount,
"The number of rows in the table should match the number of rows inserted.")

// Check that both insertions will produce equivalent tables.
if (firstTableName.isEmpty) {
firstTableName = Some(tableName)
} else {
val df1 = spark.table(firstTableName.get)
val df2 = spark.table(tableName)
checkAnswer(df1, df2)
val df1 = spark.table(firstTableName.get)
val df2 = spark.table(tableName)
checkAnswer(df1, df2)
}
}
}
}

test("Insert Into Values optimization - Basic literals & expressions.") {
var firstTableName: Option[String] = None
Seq(true, false).foreach { insertIntoValueImprovementEnabled =>
Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
// Create a table with a randomly generated name.
val tableName = createTable

// Set the feature flag for the InsertIntoValues improvement.
withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
insertIntoValueImprovementEnabled.toString) {
withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
eagerEvalOfUnresolvedInlineTableEnabled.toString) {

// Generate an INSERT INTO VALUES statement.
val sqlStatement = generateInsertStatementsWithComplexExpressions(tableName)

// Parse the SQL statement.
val plan = parser.parsePlan(sqlStatement)

// Traverse the plan and check for the presence of appropriate nodes depending on the
// feature flag.
if (eagerEvalOfUnresolvedInlineTableEnabled) {
assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
} else {
assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
}

spark.sql(sqlStatement)

// Check that both insertions will produce equivalent tables.
Expand All @@ -168,17 +216,30 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess

test("Insert Into Values with defaults.") {
var firstTableName: Option[String] = None
Seq(true, false).foreach { insertIntoValueImprovementEnabled =>
Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
// Create a table with default values specified.
val tableName = createTable

// Set the feature flag for the InsertIntoValues improvement.
withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
insertIntoValueImprovementEnabled.toString) {
withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
eagerEvalOfUnresolvedInlineTableEnabled.toString) {

// Generate an INSERT INTO VALUES statement that omits all columns
// containing a DEFAULT value.
spark.sql(s"INSERT INTO $tableName (id) VALUES (1);")
val sqlStatement = s"INSERT INTO $tableName (id) VALUES (1);"

// Parse the SQL statement.
val plan = parser.parsePlan(sqlStatement)

// Traverse the plan and check for the presence of appropriate nodes depending on the
// feature flag.
if (eagerEvalOfUnresolvedInlineTableEnabled) {
assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
} else {
assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
}

spark.sql(sqlStatement)

// Verify that the default values are applied correctly.
val resultRow = spark.sql(
Expand Down Expand Up @@ -226,4 +287,72 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
}
}
}

test("Value list in subquery") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a JIRA prefix in the test title

Suggested change
test("Value list in subquery") {
// Runs the same `SELECT * FROM (VALUES ...)` query once with eager evaluation of
// UnresolvedInlineTable enabled and once with it disabled. For each setting it checks which
// node type the parser produced, then verifies both settings return the same rows.
test("SPARK-49269: Value list in subquery") {

var firstDF: Option[DataFrame] = None
val flagVals = Seq(true, false)
flagVals.foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
// Toggle eager evaluation of UnresolvedInlineTable in the parser.
withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
eagerEvalOfUnresolvedInlineTableEnabled.toString) {

// Generate a subquery with a VALUES clause.
val sqlStatement = s"SELECT * FROM ($generateValuesWithComplexExpressions)"

// Parse the SQL statement.
val plan = parser.parsePlan(sqlStatement)

// Traverse the plan and check for the presence of appropriate nodes depending on the
// feature flag.
if (eagerEvalOfUnresolvedInlineTableEnabled) {
assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
} else {
assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
}

val res = spark.sql(sqlStatement)

// Both flag values must produce the same query result. The first iteration
// (flagVals.head) stores its result, so `firstDF.get` is defined on the second one.
if (flagVals.head == eagerEvalOfUnresolvedInlineTableEnabled) {
firstDF = Some(res)
} else {
checkAnswer(res, firstDF.get)
}
}
}
}

// Runs a query whose VALUES list sits inside a scalar subquery in the projection list, once
// with eager evaluation of UnresolvedInlineTable enabled and once with it disabled. For each
// setting it checks which node type the parser produced, then verifies both settings return
// the same rows.
test("Value list in projection list subquery") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

var firstDF: Option[DataFrame] = None
val flagVals = Seq(true, false)
flagVals.foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
// Toggle eager evaluation of UnresolvedInlineTable in the parser.
withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
eagerEvalOfUnresolvedInlineTableEnabled.toString) {

// Generate a subquery with a VALUES clause in the projection list.
val sqlStatement = s"SELECT (SELECT COUNT(*) FROM $generateValuesWithComplexExpressions)"

// Parse the SQL statement.
val plan = parser.parsePlan(sqlStatement)

// Traverse the plan and check for the presence of appropriate nodes depending on the
// feature flag.
if (eagerEvalOfUnresolvedInlineTableEnabled) {
assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
} else {
assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
}

val res = spark.sql(sqlStatement)

// Both flag values must produce the same query result. The first iteration
// (flagVals.head) stores its result, so `firstDF.get` is defined on the second one.
if (flagVals.head == eagerEvalOfUnresolvedInlineTableEnabled) {
firstDF = Some(res)
} else {
checkAnswer(res, firstDF.get)
}
}
}
}
}