Skip to content

Commit

Permalink
[SPARK-49098][SQL] Add write options for INSERT
Browse files Browse the repository at this point in the history
  ### What changes were proposed in this pull request?

Add support for the `tbl WITH (k1=v1, k2=v2)` options clause in INSERT, INSERT OVERWRITE, and INSERT ... REPLACE WHERE statements.

  ### Why are the changes needed?

Follow-up to SPARK-36680, which added `WITH` for the SELECT statement.

  ### Does this PR introduce _any_ user-facing change?

Adds new SQL syntax
  ### How was this patch tested?

New test in DataSourceV2SQLSuite
  ### Was this patch authored or co-authored using generative AI tooling?

No
  • Loading branch information
szehon-ho committed Aug 2, 2024
1 parent 78b83fa commit 5b23c51
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@ query
;

insertInto
: INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF errorCapturingNot EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable
| INSERT INTO TABLE? identifierReference partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable
| INSERT INTO TABLE? identifierReference REPLACE whereClause #insertIntoReplaceWhere
: INSERT OVERWRITE TABLE? identifierReference optionsClause? (partitionSpec (IF errorCapturingNot EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable
| INSERT INTO TABLE? identifierReference optionsClause? partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable
| INSERT INTO TABLE? identifierReference optionsClause? REPLACE whereClause #insertIntoReplaceWhere
| INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat? #insertOverwriteHiveDir
| INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? #insertOverwriteDir
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper

/**
* Parameters used for writing query to a table:
* (table ident, tableColumnList, partitionKeys, ifPartitionNotExists, byName).
* (table ident, options, tableColumnList, partitionKeys, ifPartitionNotExists, byName).
*/
type InsertTableParams =
(IdentifierReferenceContext, Seq[String], Map[String, Option[String]], Boolean, Boolean)
(IdentifierReferenceContext, Option[OptionsClauseContext], Seq[String],
Map[String, Option[String]], Boolean, Boolean)

/**
* Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
Expand All @@ -412,11 +413,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
// 2. Write commands do not hold the table logical plan as a child, and we need to add
// additional resolution code to resolve identifiers inside the write commands.
case table: InsertIntoTableContext =>
val (relationCtx, cols, partition, ifPartitionNotExists, byName)
val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
= visitInsertIntoTable(table)
withIdentClause(relationCtx, ident => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident),
createUnresolvedRelation(relationCtx, ident, options),
partition,
cols,
query,
Expand All @@ -425,11 +426,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
byName)
})
case table: InsertOverwriteTableContext =>
val (relationCtx, cols, partition, ifPartitionNotExists, byName)
val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, ident => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident),
createUnresolvedRelation(relationCtx, ident, options),
partition,
cols,
query,
Expand All @@ -440,7 +441,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
case ctx: InsertIntoReplaceWhereContext =>
withIdentClause(ctx.identifierReference, ident => {
OverwriteByExpression.byPosition(
createUnresolvedRelation(ctx.identifierReference, ident),
createUnresolvedRelation(ctx.identifierReference, ident, Option(ctx.optionsClause())),
query,
expression(ctx.whereClause().booleanExpression()))
})
Expand Down Expand Up @@ -469,7 +470,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
invalidStatement("INSERT INTO ... IF NOT EXISTS", ctx)
}

(ctx.identifierReference, cols, partitionKeys, false, ctx.NAME() != null)
(ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys, false,
ctx.NAME() != null)
}

/**
Expand All @@ -489,7 +491,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
dynamicPartitionKeys.keys.mkString(", "), ctx)
}

(ctx.identifierReference, cols, partitionKeys, ctx.EXISTS() != null, ctx.NAME() != null)
(ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys,
ctx.EXISTS() != null, ctx.NAME() != null)
}

/**
Expand Down Expand Up @@ -3067,9 +3070,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext,
optionsClause: Option[OptionsClauseContext] = None): LogicalPlan = withOrigin(ctx) {
val options = optionsClause.map{ clause =>
new CaseInsensitiveStringMap(visitPropertyKeyValues(clause.options).asJava)
}.getOrElse(CaseInsensitiveStringMap.empty)
val options = resolveOptions(optionsClause)
withIdentClause(ctx, parts =>
new UnresolvedRelation(parts, options, isStreaming = false))
}
Expand All @@ -3078,8 +3079,18 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
* Create an [[UnresolvedRelation]] from a multi-part identifier.
*/
private def createUnresolvedRelation(
ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) {
UnresolvedRelation(ident)
ctx: ParserRuleContext,
ident: Seq[String],
optionsClause: Option[OptionsClauseContext]): UnresolvedRelation = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
new UnresolvedRelation(ident, options, isStreaming = false)
}

/**
 * Build the option map for a relation from an optional options clause.
 *
 * @param optionsClause the parsed options clause, if one was present
 * @return a [[CaseInsensitiveStringMap]] holding the clause's key/value pairs, or the
 *         empty map when no clause was supplied
 */
private def resolveOptions(
    optionsClause: Option[OptionsClauseContext]): CaseInsensitiveStringMap = {
  optionsClause match {
    case Some(clause) =>
      val keyValues = visitPropertyKeyValues(clause.options)
      new CaseInsensitiveStringMap(keyValues.asJava)
    case None =>
      CaseInsensitiveStringMap.empty
  }
}

/**
Expand Down Expand Up @@ -4948,7 +4959,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options)
CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None),
ident, isLazy, options)
}
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, ColumnStat, CommandResult, OverwriteByExpression}
import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, _}
Expand All @@ -41,6 +41,7 @@ import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
Expand Down Expand Up @@ -3547,6 +3548,66 @@ class DataSourceV2SQLSuiteV1Filter
}
}

// Checks that options supplied through `WITH (...)` on INSERT INTO are carried by the
// DataSourceV2Relation that the append targets, and that the rows are written.
test("SPARK-49098: Supports Dynamic Table Options for Insert") {
  val t1 = s"${catalogAndNamespace}table"
  withTable(t1) {
    sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
    val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")

    // The optimized plan must contain exactly one AppendData whose relation has the option.
    val collected = df.queryExecution.optimizedPlan.collect {
      case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
        assert(relation.options.get("write.split-size") == "10")
    }
    assert(collected.size == 1)

    val insertResult = sql(s"SELECT * FROM $t1")
    checkAnswer(insertResult, Seq(Row(1, "a"), Row(2, "b")))
  }
}

// Checks that options supplied through `WITH (...)` on INSERT OVERWRITE are carried by the
// DataSourceV2Relation of the resulting OverwriteByExpression, and that the table ends up
// holding only the newly written rows.
test("SPARK-49098: Supports Dynamic Table Options for Insert Overwrite") {
  val t1 = s"${catalogAndNamespace}table"
  withTable(t1) {
    sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
    // Seed rows that the overwrite below is expected to replace.
    sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")

    val df = sql(s"INSERT OVERWRITE $t1 WITH (`write.split-size` = 10) " +
      "VALUES (3, 'c'), (4, 'd')")
    // The optimized plan must contain exactly one overwrite whose relation has the option.
    val collected = df.queryExecution.optimizedPlan.collect {
      case CommandResult(_,
        OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
        _, _) =>
        assert(relation.options.get("write.split-size") == "10")
    }
    assert(collected.size == 1)

    val insertResult = sql(s"SELECT * FROM $t1")
    checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
  }
}

// Checks that options supplied through `WITH (...)` on INSERT INTO ... REPLACE WHERE are
// carried by the DataSourceV2Relation of the resulting OverwriteByExpression, and that the
// table ends up holding only the newly written rows.
test("SPARK-49098: Supports Dynamic Table Options for Insert Replace") {
  val t1 = s"${catalogAndNamespace}table"
  withTable(t1) {
    sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
    // Seed rows; `REPLACE WHERE TRUE` below matches all of them.
    sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")

    val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) " +
      "REPLACE WHERE TRUE " +
      "VALUES (3, 'c'), (4, 'd')")
    // The optimized plan must contain exactly one overwrite whose relation has the option.
    val collected = df.queryExecution.optimizedPlan.collect {
      case CommandResult(_,
        OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
        _, _) =>
        assert(relation.options.get("write.split-size") == "10")
    }
    assert(collected.size == 1)

    val insertResult = sql(s"SELECT * FROM $t1")
    checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
  }
}

private def testNotSupportedV2Command(
sqlCommand: String,
sqlParams: String,
Expand Down

0 comments on commit 5b23c51

Please sign in to comment.