[SPARK-43963][SQL] DataSource V2: Handle MERGE commands for group-based sources

This PR adds support for group-based data sources in `RewriteMergeIntoTable`. It builds on PR apache#41448 and the earlier PRs that added `RewriteDeleteFromTable`.

These changes are needed per SPIP SPARK-35801.

There are no user-facing changes.

This PR comes with tests. There are more tests in `AlignMergeAssignmentsSuite`, which was merged earlier.
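
For illustration, a statement like the one below now takes the group-based (copy-on-write) path when the target is a DSv2 table that reports `SupportsRowLevelOperations` without `SupportsDelta`. This is a minimal sketch: the session, catalog, and table names are hypothetical, and `spark` is assumed to be an active SparkSession.

// hypothetical names; `cat.db.target` is assumed to be backed by a
// group-based source, i.e. one that rewrites whole files or partitions
spark.sql("""
  MERGE INTO cat.db.target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED AND s.op = 'delete' THEN DELETE
  WHEN MATCHED THEN UPDATE SET t.value = s.value
  WHEN NOT MATCHED THEN INSERT (id, value) VALUES (s.id, s.value)
""")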

Closes apache#41577 from aokolnychyi/spark-43963.

Authored-by: aokolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit 305506e)
aokolnychyi authored and huaxingao committed Jun 30, 2023
1 parent 05879f2 commit c2006f7
Showing 16 changed files with 583 additions and 65 deletions.
@@ -86,7 +86,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- ReplaceData(writeRelation, cond, remainingRowsPlan, relation)
+ ReplaceData(writeRelation, cond, remainingRowsPlan, relation, Some(cond))
}

// build a rewrite plan for sources that support row deltas
@@ -18,11 +18,11 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
- import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Expression, IsNotNull, MetadataAttribute, MonotonicallyIncreasingID, PredicateHelper}
+ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Exists, Expression, IsNotNull, Literal, MetadataAttribute, MonotonicallyIncreasingID, OuterReference, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
- import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, UpdateAction, WriteDelta}
- import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split}
+ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
+ import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
@@ -123,14 +123,121 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
r, table, source, cond, matchedActions,
notMatchedActions, notMatchedBySourceActions)
case _ =>
- throw new AnalysisException("Group-based MERGE commands are not supported yet")
+ buildReplaceDataPlan(
+   r, table, source, cond, matchedActions,
+   notMatchedActions, notMatchedBySourceActions)
}

case _ =>
m
}
}

// build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
private def buildReplaceDataPlan(
relation: DataSourceV2Relation,
operationTable: RowLevelOperationTable,
source: LogicalPlan,
cond: Expression,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]): ReplaceData = {

// resolve all required metadata attrs that may be used for grouping data on write
// for instance, JDBC data source may cluster data by shard/host before writing
val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation)

// construct a read relation and include all required metadata columns
val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)

val checkCardinality = shouldCheckCardinality(matchedActions)

// use left outer join if there is no NOT MATCHED action, since unmatched source rows can be discarded
// use full outer join in all other cases, as unmatched source rows may be needed
val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter
val joinPlan = join(readRelation, source, joinType, cond, checkCardinality)

val mergeRowsPlan = buildReplaceDataMergeRowsPlan(
readRelation, joinPlan, matchedActions, notMatchedActions,
notMatchedBySourceActions, metadataAttrs, checkCardinality)

// predicates of the ON condition can be used to filter the target table (planning & runtime)
// only if there is no NOT MATCHED BY SOURCE clause
val (pushableCond, groupFilterCond) = if (notMatchedBySourceActions.isEmpty) {
(cond, Some(toGroupFilterCondition(relation, source, cond)))
} else {
(TrueLiteral, None)
}

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, groupFilterCond)
}
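
// To illustrate the choices above (hypothetical clauses): a MERGE with only
// MATCHED actions joins target and source with LEFT OUTER and can push both
// `cond` and an EXISTS-based group filter into the target scan; adding a
// NOT MATCHED clause forces FULL OUTER so that inserts can be produced;
// adding a NOT MATCHED BY SOURCE clause disables target-side filtering,
// because target rows that match no source row may still change.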

private def buildReplaceDataMergeRowsPlan(
targetTable: LogicalPlan,
joinPlan: LogicalPlan,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction],
metadataAttrs: Seq[Attribute],
checkCardinality: Boolean): MergeRows = {

// target records that were read but did not match any MATCHED or NOT MATCHED BY SOURCE actions
// must be copied over and included in the new state of the table as groups are being replaced
// that's why an extra unconditional instruction that would produce the original row is added
// as the last MATCHED and NOT MATCHED BY SOURCE instruction
// this logic is specific to data sources that replace groups of data
val keepCarryoverRowsInstruction = Keep(TrueLiteral, targetTable.output)

val matchedInstructions = matchedActions.map { action =>
toInstruction(action, metadataAttrs)
} :+ keepCarryoverRowsInstruction

val notMatchedInstructions = notMatchedActions.map { action =>
toInstruction(action, metadataAttrs)
}

val notMatchedBySourceInstructions = notMatchedBySourceActions.map { action =>
toInstruction(action, metadataAttrs)
} :+ keepCarryoverRowsInstruction

val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE, joinPlan)
val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET, joinPlan)

val outputs = matchedInstructions.flatMap(_.outputs) ++
notMatchedInstructions.flatMap(_.outputs) ++
notMatchedBySourceInstructions.flatMap(_.outputs)

val attrs = targetTable.output

MergeRows(
isSourceRowPresent = IsNotNull(rowFromSourceAttr),
isTargetRowPresent = IsNotNull(rowFromTargetAttr),
matchedInstructions = matchedInstructions,
notMatchedInstructions = notMatchedInstructions,
notMatchedBySourceInstructions = notMatchedBySourceInstructions,
checkCardinality = checkCardinality,
output = generateExpandOutput(attrs, outputs),
joinPlan)
}
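
// Carryover sketch (hypothetical data): if a group holds rows r1, r2, r3 and
// the MERGE matches only r2, the matched instructions emit the updated r2,
// while r1 and r3 fall through to the trailing Keep(TrueLiteral, ...) above
// and are written back unchanged, so the replaced group keeps all three rows.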

// converts a MERGE condition into an EXISTS subquery for runtime filtering
private def toGroupFilterCondition(
relation: DataSourceV2Relation,
source: LogicalPlan,
cond: Expression): Expression = {

val condWithOuterRefs = cond transformUp {
case attr: Attribute if relation.outputSet.contains(attr) => OuterReference(attr)
case other => other
}
val outerRefs = condWithOuterRefs.collect {
case OuterReference(e) => e
}
Exists(Filter(condWithOuterRefs, source), outerRefs)
}
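
// Sketch of the result: for an ON condition like `t.pk = s.pk` (column names
// hypothetical), this yields
//   Exists(Filter(OuterReference(t.pk) = s.pk, source), Seq(t.pk))
// the logical equivalent of EXISTS (SELECT 1 FROM source s WHERE t.pk = s.pk),
// so only groups that contain at least one matching row are read and rewritten.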

// build a rewrite plan for sources that support row deltas
private def buildWriteDeltaPlan(
relation: DataSourceV2Relation,
@@ -310,6 +417,26 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
}
}

// converts a MERGE action into an instruction on top of the joined plan for group-based plans
private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
action match {
case UpdateAction(cond, assignments) =>
val output = assignments.map(_.value) ++ metadataAttrs
Keep(cond.getOrElse(TrueLiteral), output)

case DeleteAction(cond) =>
Discard(cond.getOrElse(TrueLiteral))

case InsertAction(cond, assignments) =>
val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
val output = assignments.map(_.value) ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)

case other =>
throw new AnalysisException(s"Unexpected action: $other")
}
}
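
// Worked mapping for the group-based case (a missing condition defaults to
// TrueLiteral; clauses are hypothetical):
//   WHEN MATCHED [AND c] THEN UPDATE SET ...  =>  Keep(c, assigned values ++ metadata attrs)
//   WHEN MATCHED [AND c] THEN DELETE          =>  Discard(c)
//   WHEN NOT MATCHED [AND c] THEN INSERT ...  =>  Keep(c, inserted values ++ NULL metadata)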

// converts a MERGE action into an instruction on top of the joined plan for delta-based plans
private def toInstruction(
action: MergeAction,
@@ -54,7 +54,10 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
_.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
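// the optional group filter below is evaluated as a filter over the target
// scan, where a NULL result behaves like FALSE, so nulls can be replaced
// in it as well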
- case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = replaceNullWithFalse(cond))
+ case rd @ ReplaceData(_, cond, _, _, groupFilterCond, _) =>
+   val newCond = replaceNullWithFalse(cond)
+   val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse)
+   rd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond)
case wd @ WriteDelta(_, cond, _, _, _, _) => wd.copy(condition = replaceNullWithFalse(cond))
case d @ DeleteFromTable(_, cond) => d.copy(condition = replaceNullWithFalse(cond))
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
@@ -432,16 +432,18 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with PredicateHelper
* This class extracts the following entities:
* - the group-based rewrite plan;
* - the condition that defines matching groups;
* - the group filter condition;
* - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]]
* depending on whether the planning has already happened;
*/
object GroupBasedRowLevelOperation {
- type ReturnType = (ReplaceData, Expression, LogicalPlan)
+ type ReturnType = (ReplaceData, Expression, Option[Expression], LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), cond, query, _, _) =>
+ case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _),
+     cond, query, _, groupFilterCond, _) =>
    val readRelation = findReadRelation(table, query)
- readRelation.map((rd, cond, _))
+ readRelation.map((rd, cond, groupFilterCond, _))

case _ =>
None
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Instruction
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.types.DataType

@@ -76,6 +77,15 @@ object MergeRows {
}
}

// an instruction that discards (filters out) the current row when the
// condition holds; used for DELETE actions in group-based MERGE plans
case class Discard(condition: Expression) extends Instruction with UnaryLike[Expression] {
override def outputs: Seq[Seq[Expression]] = Seq.empty
override def child: Expression = condition

override protected def withNewChildInternal(newChild: Expression): Expression = {
copy(condition = newChild)
}
}

case class Split(
condition: Expression,
output: Seq[Expression],
@@ -212,13 +212,15 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery {
* @param condition a condition that defines matching groups
* @param query a query with records that should replace the records that were read
* @param originalTable a plan for the original table for which the row-level command was triggered
* @param groupFilterCondition a condition that can be used to filter groups at runtime
* @param write a logical write, if already constructed
*/
case class ReplaceData(
table: NamedRelation,
condition: Expression,
query: LogicalPlan,
originalTable: NamedRelation,
groupFilterCondition: Option[Expression] = None,
write: Option[Write] = None) extends RowLevelWrite {

override val isByName: Boolean = false
Expand Down