[SPARK-44360][SQL] Support schema pruning in delta-based MERGE operations

### What changes were proposed in this pull request?

This PR adds support for schema pruning in delta-based MERGE operations.

### Why are the changes needed?

These changes are needed to improve the performance of certain row-level operations by skipping columns that are not required to materialize changes.

Consider an example.

```scala
createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING")

sql(
  s"""MERGE INTO table t
     |USING source s
     |ON t.pk = s.pk
     |WHEN MATCHED AND t.salary = 200 THEN
     | UPDATE SET *
     |""".stripMargin)
```

To compute the new state of updated records, we only need the `pk` and `salary` columns from the target table. Hence, we can skip reading `country` and `dep`, as the values for those columns come from the source relation.
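Concretely, the pruning can be observed in the executed plan. Below is a sketch that mirrors the tests added in this PR (`executeAndKeepPlan`, `collect`, and `tableNameAsString` come from the existing test harness; the `_partition` metadata column is specific to the test connector):

```scala
// Sketch mirroring the new tests: run the MERGE, then check which columns
// the target-side batch scan actually requested.
val plan = executeAndKeepPlan {
  sql(
    s"""MERGE INTO $tableNameAsString t
       |USING source s
       |ON t.pk = s.pk
       |WHEN MATCHED AND t.salary = 200 THEN
       | UPDATE SET *
       |""".stripMargin)
}
val scan = collect(plan) { case s: BatchScanExec => s }.head
// Expected: `pk` and `salary` (plus connector metadata columns such as
// `_partition`), but not `country` or `dep`.
println(scan.schema.catalogString)
```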

This logic does not apply to group-based MERGE operations as those have to copy over records and need values for all columns of the target table.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with tests.

Closes apache#41930 from aokolnychyi/spark-44360.

Authored-by: aokolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
aokolnychyi authored and dongjoon-hyun committed Jul 11, 2023
1 parent 0f6a4a7 commit 37aa62f
Showing 6 changed files with 262 additions and 5 deletions.
In the optimizer rule `ColumnPruning`, a new case prunes unused columns from the child of `MergeRows`:

```diff
@@ -882,6 +882,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case e @ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) =>
       e.copy(child = prunedChild(child, e.references))
 
+    // prune unused columns from child of MergeRows for row-level operations
+    case e @ MergeRows(_, _, _, _, _, _, _, child) if !child.outputSet.subsetOf(e.references) =>
+      e.copy(child = prunedChild(child, e.references))
+
     // prune unrequired references
     case p @ Project(_, g: Generate) if p.references != g.outputSet =>
       val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references
```
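For context, `prunedChild` is a helper that already exists in `ColumnPruning`; a simplified sketch of its behavior:

```scala
// Simplified sketch of the existing ColumnPruning.prunedChild helper: when a
// child outputs attributes that are never referenced, wrap it in a Project
// that keeps only the referenced ones. The new MergeRows case above reuses
// this to drop unused target-table columns below the merge operator.
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet): LogicalPlan =
  if (!c.outputSet.subsetOf(allReferences)) {
    Project(c.output.filter(allReferences.contains), c)
  } else {
    c
  }
```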
In the logical node `MergeRows`, `references` no longer claims the entire child output; it now includes only the attributes used by the merge instructions, plus the row-ID attribute when the cardinality check needs it:

```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Instruction
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, ROW_ID}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.types.DataType
@@ -37,7 +37,17 @@ case class MergeRows(
     AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
   }
 
-  override lazy val references: AttributeSet = child.outputSet
+  @transient
+  override lazy val references: AttributeSet = {
+    val usedExprs = if (checkCardinality) {
+      val rowIdAttr = child.output.find(attr => conf.resolver(attr.name, ROW_ID))
+      assert(rowIdAttr.isDefined, "Cannot find row ID attr")
+      rowIdAttr.get +: expressions
+    } else {
+      expressions
+    }
+    AttributeSet.fromAttributeSets(usedExprs.map(_.references)) -- producedAttributes
+  }
 
   override def simpleString(maxFields: Int): String = {
     s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}"
```
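Previously, `references` was defined as the entire child output, so `ColumnPruning` could never remove anything below `MergeRows`. With this change, only attributes actually used by the merge instructions survive, plus the hidden row-ID attribute when the cardinality check relies on it. A toy illustration of the attribute arithmetic (the attributes here are hypothetical, not part of the PR):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Hypothetical target-table attributes.
val pk = AttributeReference("pk", IntegerType)()
val salary = AttributeReference("salary", IntegerType)()
val country = AttributeReference("country", StringType)()

// Suppose the merge instructions reference only pk and salary: the union of
// their references is what MergeRows now reports, so country becomes
// prunable from the target scan.
val used = AttributeSet.fromAttributeSets(Seq(pk, salary).map(_.references))
assert(used.contains(pk) && used.contains(salary))
assert(!used.contains(country))
```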
The physical node `MergeRowsExec` gets the matching change:

```diff
@@ -49,7 +49,17 @@ case class MergeRowsExec(
     AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
   }
 
-  @transient override lazy val references: AttributeSet = child.outputSet
+  @transient
+  override lazy val references: AttributeSet = {
+    val usedExprs = if (checkCardinality) {
+      val rowIdAttr = child.output.find(attr => conf.resolver(attr.name, ROW_ID))
+      assert(rowIdAttr.isDefined, "Cannot find row ID attr")
+      rowIdAttr.get +: expressions
+    } else {
+      expressions
+    }
+    AttributeSet.fromAttributeSets(usedExprs.map(_.references)) -- producedAttributes
+  }
 
   override def simpleString(maxFields: Int): String = {
     s"MergeRowsExec${truncatedString(output, "[", ", ", "]", maxFields)}"
```
`DeltaBasedMergeIntoTableSuite` now extends a new shared base suite so both delta-based MERGE suites run the pruning tests:

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connector
 
-class DeltaBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase {
+class DeltaBasedMergeIntoTableSuite extends DeltaBasedMergeIntoTableSuiteBase {
 
   override protected lazy val extraTableProps: java.util.Map[String, String] = {
     val props = new java.util.HashMap[String, String]()
```
The new base suite, `DeltaBasedMergeIntoTableSuiteBase`, holds the schema pruning tests (new file, 232 lines):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connector

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.types.StructType

abstract class DeltaBasedMergeIntoTableSuiteBase extends MergeIntoTableSuiteBase {

  import testImplicits._

  test("merge into schema pruning with WHEN MATCHED clause (update)") {
    withTempView("source") {
      createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",
        """{ "pk": 1, "salary": 100, "country": "uk", "dep": "hr" }
          |{ "pk": 2, "salary": 200, "country": "us", "dep": "corrupted" }
          |""".stripMargin)

      val sourceRows = Seq(
        (1, 100, "france", "software"),
        (2, 200, "india", "finance"),
        (3, 300, "china", "software"))
      sourceRows.toDF("pk", "salary", "country", "dep").createOrReplaceTempView("source")

      executeAndCheckScan(
        s"""MERGE INTO $tableNameAsString t
           |USING source s
           |ON t.pk = s.pk
           |WHEN MATCHED AND t.salary = 200 THEN
           | UPDATE SET *
           |""".stripMargin,
        // `pk` is used in the SEARCH condition
        // `salary` is used in the UPDATE condition
        // `_partition` is used in the requested write distribution
        expectedScanSchema = "pk INT, salary INT, _partition STRING")

      checkAnswer(
        sql(s"SELECT * FROM $tableNameAsString"),
        Seq(
          Row(1, 100, "uk", "hr"), // unchanged
          Row(2, 200, "india", "finance"))) // update
    }
  }

  test("merge into schema pruning with WHEN MATCHED clause (delete)") {
    withTempView("source") {
      createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",
        """{ "pk": 1, "salary": 100, "country": "uk", "dep": "hr" }
          |{ "pk": 2, "salary": 200, "country": "us", "dep": "corrupted" }
          |""".stripMargin)

      Seq(1, 2, 3).toDF("pk").createOrReplaceTempView("source")

      executeAndCheckScan(
        s"""MERGE INTO $tableNameAsString t
           |USING source s
           |ON t.pk = s.pk
           |WHEN MATCHED AND t.salary = 200 THEN
           | DELETE
           |""".stripMargin,
        // `pk` is used in the SEARCH condition
        // `salary` is used in the DELETE condition
        // `_partition` is used in the requested write distribution
        expectedScanSchema = "pk INT, salary INT, _partition STRING")

      checkAnswer(
        sql(s"SELECT * FROM $tableNameAsString"),
        Seq(Row(1, 100, "uk", "hr"))) // unchanged
    }
  }

  test("merge into schema pruning with WHEN NOT MATCHED clause") {
    withTempView("source") {
      createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",
        """{ "pk": 1, "salary": 100, "country": "uk", "dep": "hr" }
          |{ "pk": 2, "salary": 200, "country": "us", "dep": "software" }
          |""".stripMargin)

      val sourceRows = Seq(
        (1, 100, "france", "software"),
        (2, 200, "india", "finance"),
        (3, 300, "china", "software"))
      sourceRows.toDF("pk", "salary", "country", "dep").createOrReplaceTempView("source")

      executeAndCheckScan(
        s"""MERGE INTO $tableNameAsString t
           |USING source s
           |ON t.pk = s.pk
           |WHEN NOT MATCHED THEN
           | INSERT *
           |""".stripMargin,
        // `pk` is used in the SEARCH condition
        expectedScanSchema = "pk INT")

      checkAnswer(
        sql(s"SELECT * FROM $tableNameAsString"),
        Seq(
          Row(1, 100, "uk", "hr"), // unchanged
          Row(2, 200, "us", "software"), // unchanged
          Row(3, 300, "china", "software"))) // insert
    }
  }

  test("merge into schema pruning with WHEN NOT MATCHED BY SOURCE clause (update)") {
    withTempView("source") {
      createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",
        """{ "pk": 1, "salary": 100, "country": "uk", "dep": "hr" }
          |{ "pk": 2, "salary": 200, "country": "us", "dep": "software" }
          |""".stripMargin)

      val sourceRows = Seq(
        (2, 200, "india", "finance"),
        (3, 300, "china", "software"))
      sourceRows.toDF("pk", "salary", "country", "dep").createOrReplaceTempView("source")

      executeAndCheckScan(
        s"""MERGE INTO $tableNameAsString t
           |USING source s
           |ON t.pk = s.pk
           |WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN
           | UPDATE SET country = 'invalid', dep = 'invalid'
           |""".stripMargin,
        // `pk` is used in the SEARCH condition
        // `salary` is used in the UPDATE condition
        // `_partition` is used in the requested write distribution
        expectedScanSchema = "pk INT, salary INT, _partition STRING")

      checkAnswer(
        sql(s"SELECT * FROM $tableNameAsString"),
        Seq(
          Row(1, 100, "invalid", "invalid"), // update
          Row(2, 200, "us", "software"))) // unchanged
    }
  }

test("merge into schema pruning with WHEN NOT MATCHED BY SOURCE clause (delete)") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",
"""{ "pk": 1, "salary": 100, "country": "uk", "dep": "hr" }
|{ "pk": 2, "salary": 200, "country": "us", "dep": "software" }
|""".stripMargin)

val sourceRows = Seq(
(2, 200, "india", "finance"),
(3, 300, "china", "software"))
sourceRows.toDF("pk", "salary", "country", "dep").createOrReplaceTempView("source")

executeAndCheckScan(
s"""MERGE INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
|WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN
| DELETE
|""".stripMargin,
// `pk` is used in the SEARCH condition
// `salary` is used in the UPDATE condition
// `_partition` is used in the requested write distribution
expectedScanSchema = "pk INT, salary INT, _partition STRING")

checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(Row(2, 200, "us", "software"))) // unchanged
}
}

test("merge into schema pruning with clauses") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",
"""{ "pk": 1, "salary": 100, "country": "uk", "dep": "hr" }
|{ "pk": 2, "salary": 200, "country": "us", "dep": "software" }
|""".stripMargin)

val sourceRows = Seq(
(2, 200, "india", "finance"),
(3, 300, "china", "software"))
sourceRows.toDF("pk", "salary", "country", "dep").createOrReplaceTempView("source")

executeAndCheckScan(
s"""MERGE INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
|WHEN MATCHED THEN
| UPDATE SET *
|WHEN NOT MATCHED THEN
| INSERT *
|WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN
| DELETE
|""".stripMargin,
// `pk` is used in the SEARCH condition
// `salary` is used in the DELETE condition
// `_partition` is used in the requested write distribution
expectedScanSchema = "pk INT, salary INT, _partition STRING")

checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(
Row(2, 200, "india", "finance"), // update
Row(3, 300, "china", "software"))) // insert
}
}

private def executeAndCheckScan(
query: String,
expectedScanSchema: String): Unit = {

val executedPlan = executeAndKeepPlan {
sql(query)
}

val scan = collect(executedPlan) {
case s: BatchScanExec => s
}.head
assert(DataTypeUtils.sameType(scan.schema, StructType.fromDDL(expectedScanSchema)))
}
}
`DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite` extends the same base suite:

```diff
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.connector
 
-class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite extends MergeIntoTableSuiteBase {
+class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite
+  extends DeltaBasedMergeIntoTableSuiteBase {
 
   override protected lazy val extraTableProps: java.util.Map[String, String] = {
     val props = new java.util.HashMap[String, String]()
```
