[SPARK-43963][SQL] DataSource V2: Handle MERGE commands for group-based sources

### What changes were proposed in this pull request?

This PR adds support for group-based data sources in `RewriteMergeIntoTable`. This PR builds on top of PR #41448 and earlier PRs that added `RewriteDeleteFromTable`.
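
As an illustration, the kind of statement this change lets Spark plan for group-based (copy-on-write) sources is sketched below; the `spark` session, catalog, table, and column names are hypothetical and only show the shape of a supported MERGE.

```scala
// Hypothetical example: `cat.db.target` is a Data Source V2 table whose
// connector replaces whole groups of data (e.g. files or partitions) rather
// than emitting row-level deltas.
spark.sql(
  """MERGE INTO cat.db.target t
    |USING source s
    |ON t.id = s.id
    |WHEN MATCHED AND s.deleted THEN DELETE
    |WHEN MATCHED THEN UPDATE SET t.value = s.value
    |WHEN NOT MATCHED THEN INSERT (id, value) VALUES (s.id, s.value)""".stripMargin)
```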

### Why are the changes needed?

These changes are needed per SPIP SPARK-35801.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with tests. There are more tests in `AlignMergeAssignmentsSuite`, which was merged earlier.

Closes #41577 from aokolnychyi/spark-43963.

Authored-by: aokolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
aokolnychyi authored and dongjoon-hyun committed Jun 14, 2023
1 parent 3adbce2 commit 305506e
Showing 15 changed files with 367 additions and 67 deletions.
@@ -86,7 +86,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
ReplaceData(writeRelation, cond, remainingRowsPlan, relation)
ReplaceData(writeRelation, cond, remainingRowsPlan, relation, Some(cond))
}

// build a rewrite plan for sources that support row deltas
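The one-line change above passes the DELETE condition as the new group filter argument of `ReplaceData`, so the command's own predicate can also be used to pick which groups need rewriting at runtime. A hedged illustration with a hypothetical table:

```scala
// Only groups (e.g. files) that contain at least one row with region = 'EU'
// have to be read, rewritten without those rows, and replaced.
spark.sql("DELETE FROM cat.db.target WHERE region = 'EU'")
```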
@@ -18,11 +18,11 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Expression, IsNotNull, MetadataAttribute, MonotonicallyIncreasingID, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Exists, Expression, IsNotNull, Literal, MetadataAttribute, MonotonicallyIncreasingID, OuterReference, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
@@ -123,14 +123,121 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
r, table, source, cond, matchedActions,
notMatchedActions, notMatchedBySourceActions)
case _ =>
throw new AnalysisException("Group-based MERGE commands are not supported yet")
buildReplaceDataPlan(
r, table, source, cond, matchedActions,
notMatchedActions, notMatchedBySourceActions)
}

case _ =>
m
}
}

// build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
private def buildReplaceDataPlan(
relation: DataSourceV2Relation,
operationTable: RowLevelOperationTable,
source: LogicalPlan,
cond: Expression,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]): ReplaceData = {

// resolve all required metadata attrs that may be used for grouping data on write
// for instance, JDBC data source may cluster data by shard/host before writing
val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation)

// construct a read relation and include all required metadata columns
val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)

val checkCardinality = shouldCheckCardinality(matchedActions)

// use left outer join if there is no NOT MATCHED action, unmatched source rows can be discarded
// use full outer join in all other cases, unmatched source rows may be needed
val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter
val joinPlan = join(readRelation, source, joinType, cond, checkCardinality)

val mergeRowsPlan = buildReplaceDataMergeRowsPlan(
readRelation, joinPlan, matchedActions, notMatchedActions,
notMatchedBySourceActions, metadataAttrs, checkCardinality)

// predicates of the ON condition can be used to filter the target table (planning & runtime)
// only if there is no NOT MATCHED BY SOURCE clause
val (pushableCond, groupFilterCond) = if (notMatchedBySourceActions.isEmpty) {
(cond, Some(toGroupFilterCondition(relation, source, cond)))
} else {
(TrueLiteral, None)
}

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, groupFilterCond)
}

private def buildReplaceDataMergeRowsPlan(
targetTable: LogicalPlan,
joinPlan: LogicalPlan,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction],
metadataAttrs: Seq[Attribute],
checkCardinality: Boolean): MergeRows = {

// target records that were read but did not match any MATCHED or NOT MATCHED BY SOURCE actions
// must be copied over and included in the new state of the table as groups are being replaced
// that's why an extra unconditional instruction that would produce the original row is added
// as the last MATCHED and NOT MATCHED BY SOURCE instruction
// this logic is specific to data sources that replace groups of data
val keepCarryoverRowsInstruction = Keep(TrueLiteral, targetTable.output)

val matchedInstructions = matchedActions.map { action =>
toInstruction(action, metadataAttrs)
} :+ keepCarryoverRowsInstruction

val notMatchedInstructions = notMatchedActions.map { action =>
toInstruction(action, metadataAttrs)
}

val notMatchedBySourceInstructions = notMatchedBySourceActions.map { action =>
toInstruction(action, metadataAttrs)
} :+ keepCarryoverRowsInstruction

val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE, joinPlan)
val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET, joinPlan)

val outputs = matchedInstructions.flatMap(_.outputs) ++
notMatchedInstructions.flatMap(_.outputs) ++
notMatchedBySourceInstructions.flatMap(_.outputs)

val attrs = targetTable.output

MergeRows(
isSourceRowPresent = IsNotNull(rowFromSourceAttr),
isTargetRowPresent = IsNotNull(rowFromTargetAttr),
matchedInstructions = matchedInstructions,
notMatchedInstructions = notMatchedInstructions,
notMatchedBySourceInstructions = notMatchedBySourceInstructions,
checkCardinality = checkCardinality,
output = generateExpandOutput(attrs, outputs),
joinPlan)
}

// converts a MERGE condition into an EXISTS subquery for runtime filtering
private def toGroupFilterCondition(
relation: DataSourceV2Relation,
source: LogicalPlan,
cond: Expression): Expression = {

val condWithOuterRefs = cond transformUp {
case attr: Attribute if relation.outputSet.contains(attr) => OuterReference(attr)
case other => other
}
val outerRefs = condWithOuterRefs.collect {
case OuterReference(e) => e
}
Exists(Filter(condWithOuterRefs, source), outerRefs)
}

// build a rewrite plan for sources that support row deltas
private def buildWriteDeltaPlan(
relation: DataSourceV2Relation,
@@ -310,6 +417,26 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
}
}

// converts a MERGE action into an instruction on top of the joined plan for group-based plans
private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
action match {
case UpdateAction(cond, assignments) =>
val output = assignments.map(_.value) ++ metadataAttrs
Keep(cond.getOrElse(TrueLiteral), output)

case DeleteAction(cond) =>
Discard(cond.getOrElse(TrueLiteral))

case InsertAction(cond, assignments) =>
val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
val output = assignments.map(_.value) ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)

case other =>
throw new AnalysisException(s"Unexpected action: $other")
}
}

// converts a MERGE action into an instruction on top of the joined plan for delta-based plans
private def toInstruction(
action: MergeAction,
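For group-based MERGE, `toGroupFilterCondition` above turns the ON condition into a correlated EXISTS over the source (used only when there is no NOT MATCHED BY SOURCE clause). As a rough sketch of its semantics rather than code from the patch, and with hypothetical names, the groups that must be rewritten are those containing rows for which a predicate like this holds:

```scala
// Conceptual equivalent of the group filter: a target row matters for the
// MERGE only if some source row joins with it, so untouched groups can be kept.
spark.sql(
  """SELECT * FROM cat.db.target t
    |WHERE EXISTS (SELECT 1 FROM source s WHERE t.id = s.id)""".stripMargin)
```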
@@ -54,7 +54,10 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
_.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = replaceNullWithFalse(cond))
case rd @ ReplaceData(_, cond, _, _, groupFilterCond, _) =>
val newCond = replaceNullWithFalse(cond)
val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse)
rd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond)
case wd @ WriteDelta(_, cond, _, _, _, _) => wd.copy(condition = replaceNullWithFalse(cond))
case d @ DeleteFromTable(_, cond) => d.copy(condition = replaceNullWithFalse(cond))
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
@@ -427,16 +427,18 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre
* This class extracts the following entities:
* - the group-based rewrite plan;
* - the condition that defines matching groups;
* - the group filter condition;
* - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]]
* depending on whether the planning has already happened;
*/
object GroupBasedRowLevelOperation {
type ReturnType = (ReplaceData, Expression, LogicalPlan)
type ReturnType = (ReplaceData, Expression, Option[Expression], LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), cond, query, _, _) =>
case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _),
cond, query, _, groupFilterCond, _) =>
val readRelation = findReadRelation(table, query)
readRelation.map((rd, cond, _))
readRelation.map((rd, cond, groupFilterCond, _))

case _ =>
None
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Instruction
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.types.DataType

@@ -76,6 +77,15 @@ object MergeRows {
}
}

case class Discard(condition: Expression) extends Instruction with UnaryLike[Expression] {
override def outputs: Seq[Seq[Expression]] = Seq.empty
override def child: Expression = condition

override protected def withNewChildInternal(newChild: Expression): Expression = {
copy(condition = newChild)
}
}

case class Split(
condition: Expression,
output: Seq[Expression],
@@ -213,13 +213,15 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery {
* @param condition a condition that defines matching groups
* @param query a query with records that should replace the records that were read
* @param originalTable a plan for the original table for which the row-level command was triggered
* @param groupFilterCondition a condition that can be used to filter groups at runtime
* @param write a logical write, if already constructed
*/
case class ReplaceData(
table: NamedRelation,
condition: Expression,
query: LogicalPlan,
originalTable: NamedRelation,
groupFilterCondition: Option[Expression] = None,
write: Option[Write] = None) extends RowLevelWrite {

override val isByName: Boolean = false
@@ -310,7 +310,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
throw SparkException.internalError("Unexpected table relation: " + other)
}

case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, Some(write)) =>
case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, _, Some(write)) =>
// use the original relation to refresh the cache
ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil

@@ -39,7 +39,7 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
// push down the filter from the command condition instead of the filter in the rewrite plan,
// which is negated for data sources that only support replacing groups of data (e.g. files)
case GroupBasedRowLevelOperation(rd: ReplaceData, cond, relation: DataSourceV2Relation) =>
case GroupBasedRowLevelOperation(rd: ReplaceData, cond, _, relation: DataSourceV2Relation) =>
val table = relation.table.asRowLevelOperationTable
val scanBuilder = table.newScanBuilder(relation.options)

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Projection
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
@@ -100,6 +100,9 @@ case class MergeRowsExec(
case Keep(cond, output) =>
KeepExec(createPredicate(cond), createProjection(output))

case Discard(cond) =>
DiscardExec(createPredicate(cond))

case Split(cond, output, otherOutput) =>
SplitExec(createPredicate(cond), createProjection(output), createProjection(otherOutput))

@@ -116,6 +119,8 @@ case class MergeRowsExec(
def apply(row: InternalRow): InternalRow = projection.apply(row)
}

case class DiscardExec(condition: BasePredicate) extends InstructionExec

case class SplitExec(
condition: BasePredicate,
projection: Projection,
@@ -206,6 +211,9 @@ case class MergeRowsExec(
case keep: KeepExec =>
return keep.apply(row)

case _: DiscardExec =>
return null

case split: SplitExec =>
cachedExtraRow = split.projectExtraRow(row)
return split.projectRow(row)
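A standalone analogy (not Spark code) for how the new `Discard` instruction interacts with `Keep` when `MergeRowsExec` processes a target row: the first instruction whose condition holds decides the row's fate, `Keep` emits a possibly rewritten row, and `Discard` emits nothing, which under copy-on-write amounts to a delete. All names below are illustrative.

```scala
type Row = Map[String, Any]

sealed trait Instr
final case class Keep(cond: Row => Boolean, project: Row => Row) extends Instr
final case class Discard(cond: Row => Boolean) extends Instr

// If no instruction applies, nothing is emitted either; that is why the rewrite
// appends an unconditional Keep as the last MATCHED / NOT MATCHED BY SOURCE
// instruction, so unchanged rows are carried over into the replaced groups.
def applyInstructions(row: Row, instrs: Seq[Instr]): Option[Row] =
  instrs.collectFirst {
    case k: Keep if k.cond(row)    => Some(k.project(row))
    case d: Discard if d.cond(row) => None
  }.flatten

// A DELETE action followed by the carryover instruction:
val instrs = Seq(
  Discard(r => r.get("deleted").contains(true)), // WHEN MATCHED AND deleted THEN DELETE
  Keep(_ => true, identity))                     // copy every other matched row over
applyInstructions(Map("id" -> 1, "deleted" -> true), instrs)  // None: row is dropped
applyInstructions(Map("id" -> 2, "deleted" -> false), instrs) // Some(...): row survives
```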
@@ -73,7 +73,7 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic
type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case rd @ ReplaceData(_, cond, _, originalTable, _) =>
case rd @ ReplaceData(_, cond, _, originalTable, _, _) =>
val command = rd.operation.command
Some(rd, command, cond, originalTable)

@@ -95,7 +95,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)

case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) =>
val rowSchema = StructType.fromAttributes(rd.dataInput)
val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
val write = writeBuilder.build()
@@ -17,12 +17,16 @@

package org.apache.spark.sql.execution.dynamicpruning

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery
import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation}

/**
@@ -44,7 +48,7 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
// apply special dynamic filtering only for group-based row-level operations
case GroupBasedRowLevelOperation(replaceData, cond,
case GroupBasedRowLevelOperation(replaceData, _, Some(cond),
DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _))
if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral =>

@@ -55,7 +59,8 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla
// in order to leverage a regular batch scan in the group filter query
val originalTable = r.relation.table.asRowLevelOperationTable.table
val relation = r.relation.copy(table = originalTable)
val matchingRowsPlan = buildMatchingRowsPlan(relation, cond)
val command = replaceData.operation.command
val matchingRowsPlan = buildMatchingRowsPlan(relation, cond, command)

val filterAttrs = scan.filterAttributes
val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
@@ -71,9 +76,19 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla

private def buildMatchingRowsPlan(
relation: DataSourceV2Relation,
cond: Expression): LogicalPlan = {
cond: Expression,
command: Command): LogicalPlan = {

val matchingRowsPlan = Filter(cond, relation)
val matchingRowsPlan = command match {
case DELETE =>
Filter(cond, relation)
case UPDATE =>
throw new AnalysisException("Group-based UPDATE operations are currently not supported")
case MERGE =>
// rewrite the group filter subquery as joins
val filter = Filter(cond, relation)
RewritePredicateSubquery(filter)
}

// clone the relation and assign new expr IDs to avoid conflicts
matchingRowsPlan transformUpWithNewOutput {
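Tying the runtime-filtering change above back to MERGE: for MERGE the group filter condition handed to this rule is the EXISTS subquery produced by `toGroupFilterCondition`, which is why the matching-rows plan is now rewritten with `RewritePredicateSubquery` into a join before the runtime filter keys are derived from it. A rough, hypothetical sketch of the equivalent query used to find the groups that must be rewritten (the real plan is built from Catalyst nodes and uses the connector's `filterAttributes`, here imagined as a `_partition` column):

```scala
// Roughly: semi-join the target with the source on the MERGE condition and keep
// the distinct grouping keys; only those groups are rewritten by ReplaceData.
spark.sql(
  """SELECT DISTINCT t._partition
    |FROM cat.db.target t
    |LEFT SEMI JOIN source s ON t.id = s.id""".stripMargin)
```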