diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala index b4e077671d4e1..32163406ca6d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala @@ -86,7 +86,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - ReplaceData(writeRelation, cond, remainingRowsPlan, relation) + ReplaceData(writeRelation, cond, remainingRowsPlan, relation, Some(cond)) } // build a rewrite plan for sources that support row deltas diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index d7ed2d72c684c..7bfc476d29a1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Expression, IsNotNull, MetadataAttribute, MonotonicallyIncreasingID, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Exists, Expression, IsNotNull, Literal, MetadataAttribute, MonotonicallyIncreasingID, OuterReference, PredicateHelper} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, UpdateAction, WriteDelta} -import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta} +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} @@ -123,7 +123,9 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper r, table, source, cond, matchedActions, notMatchedActions, notMatchedBySourceActions) case _ => - throw new AnalysisException("Group-based MERGE commands are not supported yet") + buildReplaceDataPlan( + r, table, source, cond, matchedActions, + notMatchedActions, notMatchedBySourceActions) } case _ => @@ -131,6 +133,111 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } } + // build a rewrite plan for sources that support replacing groups of data (e.g. 
files, partitions) + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + source: LogicalPlan, + cond: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction]): ReplaceData = { + + // resolve all required metadata attrs that may be used for grouping data on write + // for instance, JDBC data source may cluster data by shard/host before writing + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + val checkCardinality = shouldCheckCardinality(matchedActions) + + // use left outer join if there is no NOT MATCHED action, unmatched source rows can be discarded + // use full outer join in all other cases, unmatched source rows may be needed + val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter + val joinPlan = join(readRelation, source, joinType, cond, checkCardinality) + + val mergeRowsPlan = buildReplaceDataMergeRowsPlan( + readRelation, joinPlan, matchedActions, notMatchedActions, + notMatchedBySourceActions, metadataAttrs, checkCardinality) + + // predicates of the ON condition can be used to filter the target table (planning & runtime) + // only if there is no NOT MATCHED BY SOURCE clause + val (pushableCond, groupFilterCond) = if (notMatchedBySourceActions.isEmpty) { + (cond, Some(toGroupFilterCondition(relation, source, cond))) + } else { + (TrueLiteral, None) + } + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, groupFilterCond) + } + + private def buildReplaceDataMergeRowsPlan( + targetTable: LogicalPlan, + joinPlan: LogicalPlan, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction], + metadataAttrs: Seq[Attribute], + checkCardinality: Boolean): MergeRows = { + + // target records that were read but did not match any MATCHED or NOT MATCHED BY SOURCE actions + // must be copied over and included in the new state of the table as groups are being replaced + // that's why an extra unconditional instruction that would produce the original row is added + // as the last MATCHED and NOT MATCHED BY SOURCE instruction + // this logic is specific to data sources that replace groups of data + val keepCarryoverRowsInstruction = Keep(TrueLiteral, targetTable.output) + + val matchedInstructions = matchedActions.map { action => + toInstruction(action, metadataAttrs) + } :+ keepCarryoverRowsInstruction + + val notMatchedInstructions = notMatchedActions.map { action => + toInstruction(action, metadataAttrs) + } + + val notMatchedBySourceInstructions = notMatchedBySourceActions.map { action => + toInstruction(action, metadataAttrs) + } :+ keepCarryoverRowsInstruction + + val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE, joinPlan) + val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET, joinPlan) + + val outputs = matchedInstructions.flatMap(_.outputs) ++ + notMatchedInstructions.flatMap(_.outputs) ++ + notMatchedBySourceInstructions.flatMap(_.outputs) + + val attrs = targetTable.output + + MergeRows( + isSourceRowPresent = IsNotNull(rowFromSourceAttr), + isTargetRowPresent = IsNotNull(rowFromTargetAttr), + matchedInstructions = matchedInstructions, + 
notMatchedInstructions = notMatchedInstructions, + notMatchedBySourceInstructions = notMatchedBySourceInstructions, + checkCardinality = checkCardinality, + output = generateExpandOutput(attrs, outputs), + joinPlan) + } + + // converts a MERGE condition into an EXISTS subquery for runtime filtering + private def toGroupFilterCondition( + relation: DataSourceV2Relation, + source: LogicalPlan, + cond: Expression): Expression = { + + val condWithOuterRefs = cond transformUp { + case attr: Attribute if relation.outputSet.contains(attr) => OuterReference(attr) + case other => other + } + val outerRefs = condWithOuterRefs.collect { + case OuterReference(e) => e + } + Exists(Filter(condWithOuterRefs, source), outerRefs) + } + // build a rewrite plan for sources that support row deltas private def buildWriteDeltaPlan( relation: DataSourceV2Relation, @@ -310,6 +417,26 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } } + // converts a MERGE action into an instruction on top of the joined plan for group-based plans + private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = { + action match { + case UpdateAction(cond, assignments) => + val output = assignments.map(_.value) ++ metadataAttrs + Keep(cond.getOrElse(TrueLiteral), output) + + case DeleteAction(cond) => + Discard(cond.getOrElse(TrueLiteral)) + + case InsertAction(cond, assignments) => + val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) + val output = assignments.map(_.value) ++ metadataValues + Keep(cond.getOrElse(TrueLiteral), output) + + case other => + throw new AnalysisException(s"Unexpected action: $other") + } + } + // converts a MERGE action into an instruction on top of the joined plan for delta-based plans private def toInstruction( action: MergeAction, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 9d81ba2fadbe6..c8fe6c1524b2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -54,7 +54,10 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) - case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = replaceNullWithFalse(cond)) + case rd @ ReplaceData(_, cond, _, _, groupFilterCond, _) => + val newCond = replaceNullWithFalse(cond) + val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse) + rd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond) case wd @ WriteDelta(_, cond, _, _, _, _) => wd.copy(condition = replaceNullWithFalse(cond)) case d @ DeleteFromTable(_, cond) => d.copy(condition = replaceNullWithFalse(cond)) case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ad851daedb56b..8b4db79ff30e5 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -427,16 +427,18 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre * This class extracts the following entities: * - the group-based rewrite plan; * - the condition that defines matching groups; + * - the group filter condition; * - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]] * depending on whether the planning has already happened; */ object GroupBasedRowLevelOperation { - type ReturnType = (ReplaceData, Expression, LogicalPlan) + type ReturnType = (ReplaceData, Expression, Option[Expression], LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), cond, query, _, _) => + case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), + cond, query, _, groupFilterCond, _) => val readRelation = findReadRelation(table, query) - readRelation.map((rd, cond, _)) + readRelation.map((rd, cond, groupFilterCond, _)) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala index b3785d11e5c0c..6a664eb018be5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Instruction +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types.DataType @@ -76,6 +77,15 @@ object MergeRows { } } + case class Discard(condition: Expression) extends Instruction with UnaryLike[Expression] { + override def outputs: Seq[Seq[Expression]] = Seq.empty + override def child: Expression = condition + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(condition = newChild) + } + } + case class Split( condition: Expression, output: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index bd646b7f69270..739ffa487e393 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -213,6 +213,7 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery { * @param condition a condition that defines matching groups * @param query a query with records that should replace the records that were read * @param originalTable a plan for the original table for which the row-level command was triggered + * @param groupFilterCondition a condition that can be used to filter groups at runtime * @param write a logical write, if already constructed */ case class ReplaceData( @@ -220,6 +221,7 @@ case class ReplaceData( condition: Expression, query: LogicalPlan, originalTable: NamedRelation, + groupFilterCondition: Option[Expression] = None, write: Option[Write] = None) 
extends RowLevelWrite { override val isByName: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index f7b18e6a7a053..542ac2e674864 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -310,7 +310,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat throw SparkException.internalError("Unexpected table relation: " + other) } - case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, Some(write)) => + case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, _, Some(write)) => // use the original relation to refresh the cache ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala index 48dee3f652c6f..a27740c280434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala @@ -39,7 +39,7 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { // push down the filter from the command condition instead of the filter in the rewrite plan, // which is negated for data sources that only support replacing groups of data (e.g. 
files) - case GroupBasedRowLevelOperation(rd: ReplaceData, cond, relation: DataSourceV2Relation) => + case GroupBasedRowLevelOperation(rd: ReplaceData, cond, _, relation: DataSourceV2Relation) => val table = relation.table.asRowLevelOperationTable val scanBuilder = table.newScanBuilder(relation.options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala index db381d1ea4da8..ba300df9e245d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.Projection import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split} +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SparkPlan @@ -100,6 +100,9 @@ case class MergeRowsExec( case Keep(cond, output) => KeepExec(createPredicate(cond), createProjection(output)) + case Discard(cond) => + DiscardExec(createPredicate(cond)) + case Split(cond, output, otherOutput) => SplitExec(createPredicate(cond), createProjection(output), createProjection(otherOutput)) @@ -116,6 +119,8 @@ case class MergeRowsExec( def apply(row: InternalRow): InternalRow = projection.apply(row) } + case class DiscardExec(condition: BasePredicate) extends InstructionExec + case class SplitExec( condition: BasePredicate, projection: Projection, @@ -206,6 +211,9 @@ case class MergeRowsExec( case keep: KeepExec => return keep.apply(row) + case _: DiscardExec => + return null + case split: SplitExec => cachedExtraRow = split.projectExtraRow(row) return split.projectRow(row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala index abb9d728c78df..2a89ed8f80c99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala @@ -73,7 +73,7 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case rd @ ReplaceData(_, cond, _, originalTable, _) => + case rd @ ReplaceData(_, cond, _, originalTable, _, _) => val command = rd.operation.command Some(rd, command, cond, originalTable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 8f7fed561c0bc..de43d67e621c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -95,7 +95,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt) WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics) - case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) => + case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) => val rowSchema = StructType.fromAttributes(rd.dataInput) val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema) val write = writeBuilder.build() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala index 2877ff46edb52..cdff0249ba707 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.execution.dynamicpruning +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering +import org.apache.spark.sql.connector.write.RowLevelOperation.Command +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation} /** @@ -44,7 +48,7 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { // apply special dynamic filtering only for group-based row-level operations - case GroupBasedRowLevelOperation(replaceData, cond, + case GroupBasedRowLevelOperation(replaceData, _, Some(cond), DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral => @@ -55,7 +59,8 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla // in order to leverage a regular batch scan in the group filter query val originalTable = r.relation.table.asRowLevelOperationTable.table val relation = r.relation.copy(table = originalTable) - val matchingRowsPlan = buildMatchingRowsPlan(relation, cond) + val command = replaceData.operation.command + val matchingRowsPlan = buildMatchingRowsPlan(relation, cond, command) val filterAttrs = scan.filterAttributes val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan) @@ -71,9 +76,19 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla private def buildMatchingRowsPlan( relation: DataSourceV2Relation, - cond: Expression): LogicalPlan = { + cond: Expression, + command: Command): LogicalPlan = { - val matchingRowsPlan = 
Filter(cond, relation) + val matchingRowsPlan = command match { + case DELETE => + Filter(cond, relation) + case UPDATE => + throw new AnalysisException("Group-based UPDATE operations are currently not supported") + case MERGE => + // rewrite the group filter subquery as joins + val filter = Filter(cond, relation) + RewritePredicateSubquery(filter) + } // clone the relation and assign new expr IDs to avoid conflicts matchingRowsPlan transformUpWithNewOutput { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala index 33c6e4f5be61a..6c80d46b0ef64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -18,13 +18,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression -import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.execution.InSubqueryExec -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { @@ -37,10 +31,10 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { |{ "id": 3, "salary": 120, "dep": 'hr' } |""".stripMargin) - executeDeleteAndCheckScans( + executeAndCheckScans( s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", - groupFilterScanSchema = "salary INT, dep STRING") + groupFilterScanSchema = Some("salary INT, dep STRING")) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -64,7 +58,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { val deletedDepDF = Seq(Some("software"), None).toDF() deletedDepDF.createOrReplaceTempView("deleted_dep") - executeDeleteAndCheckScans( + executeAndCheckScans( s"""DELETE FROM $tableNameAsString |WHERE | id IN (SELECT * FROM deleted_id) @@ -72,7 +66,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { | dep IN (SELECT * FROM deleted_dep) |""".stripMargin, primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", - groupFilterScanSchema = "id INT, dep STRING") + groupFilterScanSchema = Some("id INT, dep STRING")) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -117,10 +111,10 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { val deletedIdDF = Seq(Some(1), None).toDF() deletedIdDF.createOrReplaceTempView("deleted_id") - executeDeleteAndCheckScans( + executeAndCheckScans( s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", - groupFilterScanSchema = "id INT, dep STRING") + groupFilterScanSchema = Some("id INT, dep STRING")) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -129,40 +123,4 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { checkReplacedPartitions(Seq("hr")) } } - - private def executeDeleteAndCheckScans( - query: String, - primaryScanSchema: String, - groupFilterScanSchema: String): Unit = { - - val executedPlan = executeAndKeepPlan { - sql(query) - } - - val 
primaryScan = collect(executedPlan) { - case s: BatchScanExec => s - }.head - assert(DataTypeUtils.sameType(primaryScan.schema, StructType.fromDDL(primaryScanSchema))) - - primaryScan.runtimeFilters match { - case Seq(DynamicPruningExpression(child: InSubqueryExec)) => - val groupFilterScan = collect(child.plan) { - case s: BatchScanExec => s - }.head - assert(DataTypeUtils.sameType(groupFilterScan.schema, - StructType.fromDDL(groupFilterScanSchema))) - - case _ => - fail("could not find group filter scan") - } - } - - private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { - val actualPartitions = table.replacedPartitions.map { - case Seq(partValue: UTF8String) => partValue.toString - case Seq(partValue) => partValue - case other => fail(s"expected only one partition value: $other" ) - } - assert(actualPartitions == expectedPartitions, "replaced partitions must match") - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala new file mode 100644 index 0000000000000..ebc34ae006e6e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.internal.SQLConf + +class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase { + + import testImplicits._ + + test("merge runtime filtering is disabled with NOT MATCHED BY SOURCE clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "hr" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "software" } + |{ "pk": 5, "salary": 500, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq(1, 2, 3, 6).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + executeAndCheckScans( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr') + |WHEN NOT MATCHED BY SOURCE THEN + | DELETE + |""".stripMargin, + primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = None) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 201, "hr"), // update + Row(3, 301, "hr"), // update + Row(6, 0, "hr"))) // insert + + checkReplacedPartitions(Seq("hr", "software")) + } + } + + test("merge runtime group filtering (DPP enabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + checkMergeRuntimeGroupFiltering() + } + } + + test("merge runtime group filtering (DPP disabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { + checkMergeRuntimeGroupFiltering() + } + } + + test("merge runtime group filtering (AQE enabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + checkMergeRuntimeGroupFiltering() + } + } + + test("merge runtime group filtering (AQE disabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + checkMergeRuntimeGroupFiltering() + } + } + + private def checkMergeRuntimeGroupFiltering(): Unit = { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "hr" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "software" } + |{ "pk": 5, "salary": 500, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq(1, 2, 3, 6).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + executeAndCheckScans( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr') + |""".stripMargin, + primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = Some("pk INT, dep STRING")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 201, "hr"), // update + Row(3, 301, "hr"), // update + Row(4, 400, "software"), // unchanged + Row(5, 500, "software"), // unchanged + Row(6, 0, "hr"))) // insert + + checkReplacedPartitions(Seq("hr")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 0cb94709898ce..1f8c20947b893 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -22,15 +22,19 @@ import java.util.Collections import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, Encoders, QueryTest} +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.unsafe.types.UTF8String abstract class RowLevelOperationSuiteBase extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { @@ -114,4 +118,47 @@ abstract class RowLevelOperationSuiteBase stripAQEPlan(executedPlan) } + + protected def executeAndCheckScans( + query: String, + primaryScanSchema: String, + groupFilterScanSchema: Option[String]): Unit = { + + val executedPlan = executeAndKeepPlan { + sql(query) + } + + val primaryScan = collect(executedPlan) { + case s: BatchScanExec => s + }.head + assert(DataTypeUtils.sameType(primaryScan.schema, StructType.fromDDL(primaryScanSchema))) + + val groupFilterScan = primaryScan.runtimeFilters match { + case Seq(DynamicPruningExpression(child: InSubqueryExec)) => + find(child.plan) { + case _: BatchScanExec => true + case _ => false + } + case _ => + None + } + + groupFilterScanSchema match { + case Some(filterSchema) => + assert(groupFilterScan.isDefined, "could not find group filter scan") + assert(DataTypeUtils.sameType(groupFilterScan.get.schema, StructType.fromDDL(filterSchema))) + + case None => + assert(groupFilterScan.isEmpty, "should not be any group filter scans") + } + } + + protected def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { + val actualPartitions = table.replacedPartitions.map { + case Seq(partValue: UTF8String) => partValue.toString + case Seq(partValue) => partValue + case other => fail(s"expected only one partition value: $other" ) + } + assert(actualPartitions == expectedPartitions, "replaced partitions must match") + } }
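
Note: the sketch below is a minimal, plain-Scala illustration (no Spark or Catalyst dependencies) of the per-row dispatch that MergeRowsExec performs after this change, added here only to clarify the new Discard instruction. The first instruction whose predicate holds decides a joined row's fate: Keep projects an output row, Discard drops the row (used for group-based DELETE actions), and the trailing unconditional Keep copies over target rows that matched no action so they survive the group replacement. All names and types are illustrative, not Spark's; Split, cardinality checking, and unmatched-row error handling are omitted.

object MergeInstructionSketch {
  type Row = Map[String, Any]

  sealed trait Instruction { def cond: Row => Boolean }
  final case class Keep(cond: Row => Boolean, project: Row => Row) extends Instruction
  final case class Discard(cond: Row => Boolean) extends Instruction

  // Returns the rows produced for one joined row (0 or 1 here; a Split would produce 2).
  def processRow(row: Row, instructions: Seq[Instruction]): Seq[Row] =
    instructions.find(_.cond(row)) match {
      case Some(Keep(_, project)) => Seq(project(row))   // row is (re)written into the new group
      case Some(Discard(_))       => Seq.empty           // row is dropped from the replaced group
      case None                   => Seq.empty           // no applicable instruction
    }

  def main(args: Array[String]): Unit = {
    val target: Row = Map("pk" -> 1, "salary" -> 100, "dep" -> "hr")

    // MATCHED THEN UPDATE SET salary = salary + 1, followed by the unconditional carry-over Keep
    val matchedInstructions = Seq(
      Keep(_ => true, r => r.updated("salary", r("salary").asInstanceOf[Int] + 1)),
      Keep(_ => true, identity))
    println(processRow(target, matchedInstructions)) // updated copy of the target row

    // MATCHED THEN DELETE: Discard wins, so the row is simply not written back
    val deleteInstructions = Seq(Discard(_ => true), Keep(_ => true, identity))
    println(processRow(target, deleteInstructions)) // List()
  }
}

This mirrors the behavior in MergeRowsExec above, where the new DiscardExec returns null to signal that the current row should be skipped rather than projected into the replacement data.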