diff --git a/core/src/main/resources/error/README.md b/core/src/main/resources/error/README.md index 147828ec9cdb2..838991c2b6ab9 100644
--- a/core/src/main/resources/error/README.md
+++ b/core/src/main/resources/error/README.md
@@ -494,6 +494,7 @@ The following SQLSTATEs are collated from:
 |23525 |23 |Constraint Violation |525 |A violation of a constraint imposed by an XML values index occurred.|DB2 |N |DB2 |
 |23526 |23 |Constraint Violation |526 |An XML values index could not be created because the table data contains values that violate a constraint imposed by the index.|DB2 |N |DB2 |
+|23K01 |23 |Constraint Violation |K01 |MERGE cardinality violation |Spark |N |Spark |
 |23P01 |23 |Integrity Constraint Violation |P01 |exclusion_violation |PostgreSQL |N |PostgreSQL |
 |24000 |24 |invalid cursor state |000 |(no subclass) |SQL/Foundation |Y |SQL/Foundation PostgreSQL Redshift Oracle SQL Server |
 |24501 |24 |Invalid Cursor State |501 |The identified cursor is not open. |DB2 |N |DB2 |
 |24502 |24 |Invalid Cursor State |502 |The cursor identified in an OPEN statement is already open. |DB2 |N |DB2 |
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 8c3c076ce748b..75d19958c2fc8 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1539,6 +1539,13 @@
       "Parse Mode: <failFastMode>. To process malformed records as null result, try setting the option 'mode' as 'PERMISSIVE'."
     ]
   },
+  "MERGE_CARDINALITY_VIOLATION" : {
+    "message" : [
+      "The ON search condition of the MERGE statement matched a single row from the target table with multiple rows of the source table.",
+      "This could result in the target row being operated on more than once with an update or delete operation and is not allowed."
+    ],
+    "sqlState" : "23K01"
+  },
   "MISSING_AGGREGATION" : {
     "message" : [
       "The non-aggregating expression <expression> is based on columns which are not participating in the GROUP BY clause.",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index aa1b9d0e8fdcf..489040a037b9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -321,6 +321,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveRowLevelCommandAssignments ::
       RewriteDeleteFromTable ::
       RewriteUpdateTable ::
+      RewriteMergeIntoTable ::
       typeCoercionRules ++
       Seq(
         ResolveWithCTE,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
new file mode 100644
index 0000000000000..d7ed2d72c684c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Expression, IsNotNull, MetadataAttribute, MonotonicallyIncreasingID, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, UpdateAction, WriteDelta} +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A rule that rewrites MERGE operations using plans that operate on individual or groups of rows. + * + * This rule assumes the commands have been fully resolved and all assignments have been aligned. 
+ */ +object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper { + + private final val ROW_FROM_SOURCE = "__row_from_source" + private final val ROW_FROM_TARGET = "__row_from_target" + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, + notMatchedBySourceActions) if m.resolved && m.rewritable && m.aligned && + matchedActions.isEmpty && notMatchedActions.size == 1 && + notMatchedBySourceActions.isEmpty => + + EliminateSubqueryAliases(aliasedTable) match { + case r: DataSourceV2Relation => + // NOT MATCHED conditions may only refer to columns in source so they can be pushed down + val insertAction = notMatchedActions.head.asInstanceOf[InsertAction] + val filteredSource = insertAction.condition match { + case Some(insertCond) => Filter(insertCond, source) + case None => source + } + + // there is only one NOT MATCHED action, use a left anti join to remove any matching rows + // and switch to using a regular append instead of a row-level MERGE operation + // only unmatched source rows that match the condition are appended to the table + val joinPlan = Join(filteredSource, r, LeftAnti, Some(cond), JoinHint.NONE) + + val output = insertAction.assignments.map(_.value) + val outputColNames = r.output.map(_.name) + val projectList = output.zip(outputColNames).map { case (expr, name) => + Alias(expr, name)() + } + val project = Project(projectList, joinPlan) + + AppendData.byPosition(r, project) + + case _ => + m + } + + case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, + notMatchedBySourceActions) if m.resolved && m.rewritable && m.aligned && + matchedActions.isEmpty && notMatchedBySourceActions.isEmpty => + + EliminateSubqueryAliases(aliasedTable) match { + case r: DataSourceV2Relation => + // there are only NOT MATCHED actions, use a left anti join to remove any matching rows + // and switch to using a regular append instead of a row-level MERGE operation + // only unmatched source rows that match action conditions are appended to the table + val joinPlan = Join(source, r, LeftAnti, Some(cond), JoinHint.NONE) + + val notMatchedInstructions = notMatchedActions.map { + case InsertAction(cond, assignments) => + Keep(cond.getOrElse(TrueLiteral), assignments.map(_.value)) + case other => + throw new AnalysisException(s"Unexpected WHEN NOT MATCHED action: $other") + } + + val outputs = notMatchedInstructions.flatMap(_.outputs) + + // merge rows as there are multiple NOT MATCHED actions + val mergeRows = MergeRows( + isSourceRowPresent = TrueLiteral, + isTargetRowPresent = FalseLiteral, + matchedInstructions = Nil, + notMatchedInstructions = notMatchedInstructions, + notMatchedBySourceInstructions = Nil, + checkCardinality = false, + output = generateExpandOutput(r.output, outputs), + joinPlan) + + AppendData.byPosition(r, mergeRows) + + case _ => + m + } + + case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, + notMatchedBySourceActions) if m.resolved && m.rewritable && m.aligned => + + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty()) + table.operation match { + case _: SupportsDelta => + buildWriteDeltaPlan( + r, table, source, cond, matchedActions, + notMatchedActions, notMatchedBySourceActions) + case _ => + throw new 
AnalysisException("Group-based MERGE commands are not supported yet") + } + + case _ => + m + } + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + source: LogicalPlan, + cond: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction]): WriteDelta = { + + val operation = operationTable.operation.asInstanceOf[SupportsDelta] + + // resolve all needed attrs (e.g. row ID and any required metadata attrs) + val rowAttrs = relation.output + val rowIdAttrs = resolveRowIdAttrs(relation, operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) + + // if there is no NOT MATCHED BY SOURCE clause, predicates of the ON condition that + // reference only the target table can be pushed down + val (filteredReadRelation, joinCond) = if (notMatchedBySourceActions.isEmpty) { + pushDownTargetPredicates(readRelation, cond) + } else { + (readRelation, cond) + } + + val checkCardinality = shouldCheckCardinality(matchedActions) + + val joinType = chooseWriteDeltaJoinType(notMatchedActions, notMatchedBySourceActions) + val joinPlan = join(filteredReadRelation, source, joinType, joinCond, checkCardinality) + + val mergeRowsPlan = buildWriteDeltaMergeRowsPlan( + readRelation, joinPlan, matchedActions, notMatchedActions, + notMatchedBySourceActions, rowIdAttrs, checkCardinality, + operation.representUpdateAsDeleteAndInsert) + + // build a plan to write the row delta to the table + val writeRelation = relation.copy(table = operationTable) + val projections = buildWriteDeltaProjections(mergeRowsPlan, rowAttrs, rowIdAttrs, metadataAttrs) + WriteDelta(writeRelation, cond, mergeRowsPlan, relation, projections) + } + + private def chooseWriteDeltaJoinType( + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction]): JoinType = { + + val unmatchedTargetRowsRequired = notMatchedBySourceActions.nonEmpty + val unmatchedSourceRowsRequired = notMatchedActions.nonEmpty + + if (unmatchedTargetRowsRequired && unmatchedSourceRowsRequired) { + FullOuter + } else if (unmatchedTargetRowsRequired) { + LeftOuter + } else if (unmatchedSourceRowsRequired) { + RightOuter + } else { + Inner + } + } + + private def buildWriteDeltaMergeRowsPlan( + targetTable: DataSourceV2Relation, + joinPlan: LogicalPlan, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction], + rowIdAttrs: Seq[Attribute], + checkCardinality: Boolean, + splitUpdates: Boolean): MergeRows = { + + val (metadataAttrs, rowAttrs) = targetTable.output.partition { attr => + MetadataAttribute.isValid(attr.metadata) + } + + val originalRowIdValues = if (splitUpdates) { + Seq.empty + } else { + // original row ID values must be preserved and passed back to the table to encode updates + // if there are any assignments to row ID attributes, add extra columns for original values + val updateAssignments = (matchedActions ++ notMatchedBySourceActions).flatMap { + case UpdateAction(_, assignments) => assignments + case _ => Nil + } + buildOriginalRowIdValues(rowIdAttrs, updateAssignments) + } + + val matchedInstructions = matchedActions.map { action => + toInstruction(action, rowAttrs, rowIdAttrs, 
metadataAttrs, originalRowIdValues, splitUpdates) + } + + val notMatchedInstructions = notMatchedActions.map { action => + toInstruction(action, rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues, splitUpdates) + } + + val notMatchedBySourceInstructions = notMatchedBySourceActions.map { action => + toInstruction(action, rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues, splitUpdates) + } + + val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE, joinPlan) + val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET, joinPlan) + + val outputs = matchedInstructions.flatMap(_.outputs) ++ + notMatchedInstructions.flatMap(_.outputs) ++ + notMatchedBySourceInstructions.flatMap(_.outputs) + + val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)() + val originalRowIdAttrs = originalRowIdValues.map(_.toAttribute) + val attrs = Seq(operationTypeAttr) ++ targetTable.output ++ originalRowIdAttrs + + MergeRows( + isSourceRowPresent = IsNotNull(rowFromSourceAttr), + isTargetRowPresent = IsNotNull(rowFromTargetAttr), + matchedInstructions = matchedInstructions, + notMatchedInstructions = notMatchedInstructions, + notMatchedBySourceInstructions = notMatchedBySourceInstructions, + checkCardinality = checkCardinality, + output = generateExpandOutput(attrs, outputs), + joinPlan) + } + + private def pushDownTargetPredicates( + targetTable: LogicalPlan, + cond: Expression): (LogicalPlan, Expression) = { + + val predicates = splitConjunctivePredicates(cond) + val (targetPredicates, joinPredicates) = predicates.partition { predicate => + predicate.references.subsetOf(targetTable.outputSet) + } + val targetCond = targetPredicates.reduceOption(And).getOrElse(TrueLiteral) + val joinCond = joinPredicates.reduceOption(And).getOrElse(TrueLiteral) + (Filter(targetCond, targetTable), joinCond) + } + + private def join( + targetTable: LogicalPlan, + source: LogicalPlan, + joinType: JoinType, + joinCond: Expression, + checkCardinality: Boolean): LogicalPlan = { + + // project an extra column to check if a target row exists after the join + // if needed, project a synthetic row ID used to perform the cardinality check later + val rowFromTarget = Alias(TrueLiteral, ROW_FROM_TARGET)() + val targetTableProjExprs = if (checkCardinality) { + val rowId = Alias(MonotonicallyIncreasingID(), ROW_ID)() + targetTable.output ++ Seq(rowFromTarget, rowId) + } else { + targetTable.output :+ rowFromTarget + } + val targetTableProj = Project(targetTableProjExprs, targetTable) + + // project an extra column to check if a source row exists after the join + val rowFromSource = Alias(TrueLiteral, ROW_FROM_SOURCE)() + val sourceTableProjExprs = source.output :+ rowFromSource + val sourceTableProj = Project(sourceTableProjExprs, source) + + // the cardinality check prohibits broadcasting and replicating the target table + // all matches for a particular target row must be in one partition + val joinHint = if (checkCardinality) { + JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))), rightHint = None) + } else { + JoinHint.NONE + } + Join(targetTableProj, sourceTableProj, joinType, Some(joinCond), joinHint) + } + + // skip the cardinality check in these cases: + // - no MATCHED actions + // - there is only one MATCHED action and it is an unconditional DELETE + private def shouldCheckCardinality(matchedActions: Seq[MergeAction]): Boolean = { + matchedActions match { + case Nil => false + case Seq(DeleteAction(None)) => false + case _ => true + } + } + + // converts a MERGE action 
into an instruction on top of the joined plan for delta-based plans + private def toInstruction( + action: MergeAction, + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute], + originalRowIdValues: Seq[Alias], + splitUpdates: Boolean): Instruction = { + + action match { + case UpdateAction(cond, assignments) if splitUpdates => + val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues) + val otherOutput = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues) + Split(cond.getOrElse(TrueLiteral), output, otherOutput) + + case UpdateAction(cond, assignments) => + val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues) + Keep(cond.getOrElse(TrueLiteral), output) + + case DeleteAction(cond) => + val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues) + Keep(cond.getOrElse(TrueLiteral), output) + + case InsertAction(cond, assignments) => + val output = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues) + Keep(cond.getOrElse(TrueLiteral), output) + + case other => + throw new AnalysisException(s"Unexpected action: $other") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index ea40a9d5ca084..9d120e6fc2b06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable import org.apache.spark.sql.catalyst.ProjectingInternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, V2ExpressionUtils} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.catalyst.util.WriteDeltaProjections import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.errors.QueryCompilationErrors @@ -91,12 +92,17 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { rowIdAttrs } + protected def resolveAttrRef(name: String, plan: LogicalPlan): AttributeReference = { + V2ExpressionUtils.resolveRef[AttributeReference](FieldReference(name), plan) + } + protected def deltaDeleteOutput( rowAttrs: Seq[Attribute], rowIdAttrs: Seq[Attribute], - metadataAttrs: Seq[Attribute]): Seq[Expression] = { + metadataAttrs: Seq[Attribute], + originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = { val rowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs) - Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataAttrs + Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues } private def 
buildDeltaDeleteRowValues( @@ -112,10 +118,20 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { } protected def deltaInsertOutput( - rowValues: Seq[Expression], - metadataAttrs: Seq[Attribute]): Seq[Expression] = { - val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) - Seq(Literal(INSERT_OPERATION)) ++ rowValues ++ metadataValues + assignments: Seq[Assignment], + metadataAttrs: Seq[Attribute], + originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = { + val rowValues = assignments.map(_.value) + val extraNullValues = (metadataAttrs ++ originalRowIdValues).map(e => Literal(null, e.dataType)) + Seq(Literal(INSERT_OPERATION)) ++ rowValues ++ extraNullValues + } + + protected def deltaUpdateOutput( + assignments: Seq[Assignment], + metadataAttrs: Seq[Attribute], + originalRowIdValues: Seq[Expression]): Seq[Expression] = { + val rowValues = assignments.map(_.value) + Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues } protected def buildWriteDeltaProjections( @@ -167,4 +183,36 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { private def findColOrdinal(plan: LogicalPlan, name: String): Int = { plan.output.indexWhere(attr => conf.resolver(attr.name, name)) } + + protected def buildOriginalRowIdValues( + rowIdAttrs: Seq[Attribute], + assignments: Seq[Assignment]): Seq[Alias] = { + val rowIdAttrSet = AttributeSet(rowIdAttrs) + assignments.flatMap { assignment => + val key = assignment.key.asInstanceOf[Attribute] + val value = assignment.value + if (rowIdAttrSet.contains(key) && !key.semanticEquals(value)) { + Some(Alias(key, ORIGINAL_ROW_ID_VALUE_PREFIX + key.name)()) + } else { + None + } + } + } + + // generates output attributes with fresh expr IDs and correct nullability for nodes like Expand + // and MergeRows where there are multiple outputs for each input row + protected def generateExpandOutput( + attrs: Seq[Attribute], + outputs: Seq[Seq[Expression]]): Seq[Attribute] = { + + // build a correct nullability map for output attributes + // an attribute is nullable if at least one output may produce null + val nullabilityMap = attrs.indices.map { index => + index -> outputs.exists(output => output(index).nullable) + }.toMap + + attrs.zipWithIndex.map { case (attr, index) => + AttributeReference(attr.name, attr.dataType, nullabilityMap(index), attr.metadata)() + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 048f07395b302..770e2c596372f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, Literal, MetadataAttribute} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, Literal, MetadataAttribute} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ @@ -108,16 +108,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { 
// original row ID values must be preserved and passed back to the table to encode updates // if there are any assignments to row ID attributes, add extra columns for the original values - val rowIdAttrSet = AttributeSet(rowIdAttrs) - val originalRowIdValues = assignments.flatMap { assignment => - val key = assignment.key.asInstanceOf[Attribute] - val value = assignment.value - if (rowIdAttrSet.contains(key) && !key.semanticEquals(value)) { - Some(Alias(key, ORIGINAL_ROW_ID_VALUE_PREFIX + key.name)()) - } else { - None - } - } + val originalRowIdValues = buildOriginalRowIdValues(rowIdAttrs, assignments) val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)() @@ -133,27 +124,11 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { MetadataAttribute.isValid(attr.metadata) } val deleteOutput = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs) - val insertOutput = deltaInsertOutput(assignments.map(_.value), metadataAttrs) - val output = buildDeletesAndInsertsOutput(matchedRowsPlan, deleteOutput, insertOutput) - Expand(Seq(deleteOutput, insertOutput), output, matchedRowsPlan) - } - - private def buildDeletesAndInsertsOutput( - child: LogicalPlan, - deleteOutput: Seq[Expression], - insertOutput: Seq[Expression]): Seq[Attribute] = { - + val insertOutput = deltaInsertOutput(assignments, metadataAttrs) + val outputs = Seq(deleteOutput, insertOutput) val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)() - val attrs = operationTypeAttr +: child.output - - // build a correct nullability map for output attributes - // an attribute is nullable if at least one output projection may produce null - val nullabilityMap = attrs.indices.map { index => - index -> (deleteOutput(index).nullable || insertOutput(index).nullable) - }.toMap - - attrs.zipWithIndex.map { case (attr, index) => - AttributeReference(attr.name, attr.dataType, nullabilityMap(index))() - } + val attrs = operationTypeAttr +: matchedRowsPlan.output + val expandOutput = generateExpandOutput(attrs, outputs) + Expand(outputs, expandOutput, matchedRowsPlan) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala new file mode 100644 index 0000000000000..b3785d11e5c0c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable} +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.Instruction +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.types.DataType + +case class MergeRows( + isSourceRowPresent: Expression, + isTargetRowPresent: Expression, + matchedInstructions: Seq[Instruction], + notMatchedInstructions: Seq[Instruction], + notMatchedBySourceInstructions: Seq[Instruction], + checkCardinality: Boolean, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + override lazy val producedAttributes: AttributeSet = { + AttributeSet(output.filterNot(attr => inputSet.contains(attr))) + } + + override lazy val references: AttributeSet = child.outputSet + + override def simpleString(maxFields: Int): String = { + s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}" + } + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(child = newChild) + } +} + +object MergeRows { + final val ROW_ID = "__row_id" + + /** + * When a MERGE operation is rewritten, the target table is joined with the source and each + * MATCHED/NOT MATCHED/NOT MATCHED BY SOURCE clause is converted into a corresponding instruction + * on top of the joined plan. The purpose of an instruction is to derive an output row + * based on a joined row. + * + * Instructions are valid expressions so that they will be properly transformed by the analyzer + * and optimizer. + */ + sealed trait Instruction extends Expression with Unevaluable { + def condition: Expression + def outputs: Seq[Seq[Expression]] + override def nullable: Boolean = false + override def dataType: DataType = throw new UnsupportedOperationException("dataType") + } + + case class Keep(condition: Expression, output: Seq[Expression]) extends Instruction { + def children: Seq[Expression] = condition +: output + override def outputs: Seq[Seq[Expression]] = Seq(output) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(condition = newChildren.head, output = newChildren.tail) + } + } + + case class Split( + condition: Expression, + output: Seq[Expression], + otherOutput: Seq[Expression]) extends Instruction { + + def children: Seq[Expression] = Seq(condition) ++ output ++ otherOutput + override def outputs: Seq[Seq[Expression]] = Seq(output, otherOutput) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + val newCondition = newChildren.head + val newOutput = newChildren.slice(from = 1, until = output.size + 1) + val newOtherOutput = newChildren.takeRight(otherOutput.size) + copy(condition = newCondition, output = newOutput, otherOutput = newOtherOutput) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index fd09e99b9ee64..c2b81caa0e5f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2849,4 +2849,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { "data" -> data, "enumString" -> enumString)) } + + def mergeCardinalityViolationError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = 
"MERGE_CARDINALITY_VIOLATION", + messageParameters = Map.empty) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 471a9393b7d86..ab672318c2c55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -319,6 +319,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat // use the original relation to refresh the cache WriteDeltaExec(planLater(query), refreshCache(r), projections, write) :: Nil + case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions, + notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) => + MergeRowsExec(isSourceRowPresent, isTargetRowPresent, matchedInstructions, + notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, + output, planLater(child)) :: Nil + case WriteToContinuousDataSource(writer, query, customMetrics) => WriteToContinuousDataSourceExec(writer, planLater(query), customMetrics) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala new file mode 100644 index 0000000000000..db381d1ea4da8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.roaringbitmap.longlong.Roaring64Bitmap + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.BasePredicate +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split} +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode + +case class MergeRowsExec( + isSourceRowPresent: Expression, + isTargetRowPresent: Expression, + matchedInstructions: Seq[Instruction], + notMatchedInstructions: Seq[Instruction], + notMatchedBySourceInstructions: Seq[Instruction], + checkCardinality: Boolean, + output: Seq[Attribute], + child: SparkPlan) extends UnaryExecNode { + + @transient override lazy val producedAttributes: AttributeSet = { + AttributeSet(output.filterNot(attr => inputSet.contains(attr))) + } + + @transient override lazy val references: AttributeSet = child.outputSet + + override def simpleString(maxFields: Int): String = { + s"MergeRowsExec${truncatedString(output, "[", ", ", "]", maxFields)}" + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(child = newChild) + } + + protected override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions(processPartition) + } + + private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val isSourceRowPresentPred = createPredicate(isSourceRowPresent) + val isTargetRowPresentPred = createPredicate(isTargetRowPresent) + + val matchedInstructionExecs = planInstructions(matchedInstructions) + val notMatchedInstructionExecs = planInstructions(notMatchedInstructions) + val notMatchedBySourceInstructionExecs = planInstructions(notMatchedBySourceInstructions) + + val cardinalityValidator = if (checkCardinality) { + val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID)) + assert(rowIdOrdinal != -1, "Cannot find row ID attr") + BitmapCardinalityValidator(rowIdOrdinal) + } else { + NoopCardinalityValidator + } + + val mergeIterator = new MergeRowIterator( + rowIterator, cardinalityValidator, isTargetRowPresentPred, isSourceRowPresentPred, + matchedInstructionExecs, notMatchedInstructionExecs, notMatchedBySourceInstructionExecs) + + // null indicates a record must be discarded + mergeIterator.filter(_ != null) + } + + private def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + UnsafeProjection.create(exprs, child.output) + } + + private def createPredicate(expr: Expression): BasePredicate = { + GeneratePredicate.generate(expr, child.output) + } + + private def planInstructions(instructions: Seq[Instruction]): Seq[InstructionExec] = { + instructions.map { + case Keep(cond, output) => + KeepExec(createPredicate(cond), createProjection(output)) + + case Split(cond, output, otherOutput) => + SplitExec(createPredicate(cond), createProjection(output), 
createProjection(otherOutput)) + + case other => + throw new AnalysisException(s"Unexpected instruction: $other") + } + } + + sealed trait InstructionExec { + def condition: BasePredicate + } + + case class KeepExec(condition: BasePredicate, projection: Projection) extends InstructionExec { + def apply(row: InternalRow): InternalRow = projection.apply(row) + } + + case class SplitExec( + condition: BasePredicate, + projection: Projection, + otherProjection: Projection) extends InstructionExec { + def projectRow(row: InternalRow): InternalRow = projection.apply(row) + def projectExtraRow(row: InternalRow): InternalRow = otherProjection.apply(row) + } + + sealed trait CardinalityValidator { + def validate(row: InternalRow): Unit + } + + object NoopCardinalityValidator extends CardinalityValidator { + def validate(row: InternalRow): Unit = {} + } + + /** + * A simple cardinality validator that keeps track of seen row IDs in a roaring bitmap. + * This validator assumes the target table is never broadcasted or replicated, which guarantees + * matches for one target row are always co-located in the same partition. + * + * IDs are generated by [[org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID]]. + */ + case class BitmapCardinalityValidator(rowIdOrdinal: Int) extends CardinalityValidator { + // use Roaring64Bitmap as row IDs generated by MonotonicallyIncreasingID are 64-bit integers + private val matchedRowIds = new Roaring64Bitmap() + + override def validate(row: InternalRow): Unit = { + val currentRowId = row.getLong(rowIdOrdinal) + if (matchedRowIds.contains(currentRowId)) { + throw QueryExecutionErrors.mergeCardinalityViolationError() + } + matchedRowIds.add(currentRowId) + } + } + + /** + * An iterator that acts on joined target and source rows and computes deletes, updates and + * inserts according to provided MERGE instructions. + * + * If a particular joined row should be discarded, this iterator returns null. 
+ */ + class MergeRowIterator( + private val rowIterator: Iterator[InternalRow], + private val cardinalityValidator: CardinalityValidator, + private val isTargetRowPresentPred: BasePredicate, + private val isSourceRowPresentPred: BasePredicate, + private val matchedInstructions: Seq[InstructionExec], + private val notMatchedInstructions: Seq[InstructionExec], + private val notMatchedBySourceInstructions: Seq[InstructionExec]) + extends Iterator[InternalRow] { + + var cachedExtraRow: InternalRow = _ + + override def hasNext: Boolean = cachedExtraRow != null || rowIterator.hasNext + + override def next(): InternalRow = { + if (cachedExtraRow != null) { + val extraRow = cachedExtraRow + cachedExtraRow = null + return extraRow + } + + val row = rowIterator.next() + + val isSourceRowPresent = isSourceRowPresentPred.eval(row) + val isTargetRowPresent = isTargetRowPresentPred.eval(row) + + if (isTargetRowPresent && isSourceRowPresent) { + cardinalityValidator.validate(row) + applyInstructions(row, matchedInstructions) + } else if (isSourceRowPresent) { + applyInstructions(row, notMatchedInstructions) + } else if (isTargetRowPresent) { + applyInstructions(row, notMatchedBySourceInstructions) + } else { + null + } + } + + private def applyInstructions( + row: InternalRow, + instructions: Seq[InstructionExec]): InternalRow = { + + for (instruction <- instructions) { + if (instruction.condition.eval(row)) { + instruction match { + case keep: KeepExec => + return keep.apply(row) + + case split: SplitExec => + cachedExtraRow = split.projectExtraRow(row) + return split.projectRow(row) + } + } + } + + null + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index ff67485ce8a08..11e22a744f312 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec} -import org.apache.spark.sql.util.QueryExecutionListener abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { @@ -551,24 +549,4 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { fail("unexpected executed plan: " + other) } } - - // executes an operation and keeps the executed plan - protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - var executedPlan: SparkPlan = null - - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - executedPlan = qe.executedPlan - } - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { - } - } - spark.listenerManager.register(listener) - - func - - sparkContext.listenerBus.waitUntilEmpty() - - stripAQEPlan(executedPlan) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala new file mode 100644 index 0000000000000..fcbd655123934 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala @@ -0,0 
+1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class DeltaBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase { + + override protected lazy val extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("supports-deltas", "true") + props + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala new file mode 100644 index 0000000000000..4aa9b7c278559 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite extends MergeIntoTableSuiteBase { + + override protected lazy val extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("supports-deltas", "true") + props.put("split-updates", "true") + props + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala new file mode 100644 index 0000000000000..575cd29c993aa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -0,0 +1,1344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.SparkException +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.optimizer.BuildLeft +import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue} +import org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StringType} + +abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { + + import testImplicits._ + + test("merge into empty table with NOT MATCHED clause") { + withTempView("source") { + createTable("pk INT NOT NULL, salary INT, dep STRING") + + val sourceRows = Seq( + (1, 100, "hr"), + (2, 200, "finance"), + (3, 300, "hr")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), // insert + Row(2, 200, "finance"), // insert + Row(3, 300, "hr"))) // insert + } + } + + test("merge into empty table with conditional NOT MATCHED clause") { + withTempView("source") { + createTable("pk INT NOT NULL, salary INT, dep STRING") + + val sourceRows = Seq( + (1, 100, "hr"), + (2, 200, "finance"), + (3, 300, "hr")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED AND s.pk >= 2 THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "finance"), // insert + Row(3, 300, "hr"))) // insert + } + } + + test("merge into with conditional WHEN MATCHED clause (update)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "corrupted" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 100, "software"), + (2, 200, "finance"), + (3, 300, "software")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND s.pk = 2 THEN + | UPDATE SET * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), // unchanged + Row(2, 200, "finance"))) // update + } + } + + test("merge into with conditional WHEN MATCHED clause (delete)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "corrupted" } + |""".stripMargin) + + Seq(1, 2, 3).toDF("pk").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND 
t.salary = 200 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, 100, "hr"))) // unchanged + } + } + + test("merge into with assignments to primary key in NOT MATCHED BY SOURCE") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "finance" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 100, "software"), + (5, 500, "finance")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = -1 + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET t.pk = -1 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // update (matched) + Row(-1, 200, "finance"))) // update (not matched by source) + } + } + + test("merge into with assignments to primary key in MATCHED") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "finance" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 100, "software"), + (5, 500, "finance")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.pk = -1 + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET t.salary = -1 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(-1, 100, "hr"), // update (matched) + Row(2, -1, "finance"))) // update (not matched by source) + } + } + + test("merge with all types of clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "hr" } + |{ "pk": 5, "salary": 500, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(3, 4, 5, 6).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new') + |WHEN NOT MATCHED BY SOURCE THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(3, 301, "hr"), // update + Row(4, 401, "hr"), // update + Row(5, 501, "hr"), // update + Row(6, 0, "new"))) // insert + } + } + + test("merge with all types of clauses (update and insert star)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 101, "support"), + (2, 201, "support"), + (4, 401, "support"), + (5, 501, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND t.pk = 1 THEN + | UPDATE SET * + |WHEN NOT MATCHED AND s.pk = 4 THEN + | INSERT * + |WHEN NOT MATCHED BY SOURCE AND t.pk = t.salary / 100 THEN + | DELETE + |""".stripMargin) + + 
checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), // update + Row(2, 200, "software"), // unchanged + Row(4, 401, "support"))) // insert + } + } + + test("merge with all types of conditional clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "hr" } + |{ "pk": 5, "salary": 500, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(3, 4, 5, 6, 7).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND t.pk = 4 THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED AND pk = 6 THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new') + |WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 401, "hr"), // update + Row(5, 500, "hr"), // unchanged + Row(6, 0, "new"))) // insert + } + } + + test("merge with one NOT MATCHED BY SOURCE clause") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(1, 2).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED BY SOURCE THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), // unchanged + Row(2, 200, "software"))) // unchanged + } + } + + test("merge with one conditional NOT MATCHED BY SOURCE clause") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(2).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN + | UPDATE SET salary = -1 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + } + + test("merge with MATCHED and NOT MATCHED BY SOURCE clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(2).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | DELETE + |WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN + | UPDATE SET salary = -1 + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(3, 300, "hr"))) // unchanged + } + } + + test("merge with NOT MATCHED and NOT MATCHED BY SOURCE clauses") { + withTempView("source") { + 
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(2, 3, 4).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (pk, -1, 'new') + |WHEN NOT MATCHED BY SOURCE THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, -1, "new"))) // insert + } + } + + test("merge with multiple NOT MATCHED BY SOURCE clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(5, 6, 7).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED BY SOURCE AND salary = 100 THEN + | UPDATE SET salary = salary + 1 + |WHEN NOT MATCHED BY SOURCE AND salary = 300 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"))) // unchanged + } + } + + test("merge with MATCHED BY SOURCE clause and NULL values") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING", + """{ "pk": 1, "id": null, "salary": 100, "dep": "hr" } + |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "id": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceRows = Seq( + (2, 2, 201, "support"), + (1, 1, 101, "support"), + (3, 3, 301, "support")) + sourceRows.toDF("pk", "id", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.id = s.id AND t.id < 3 + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, null, 100, "hr"), // unchanged + Row(2, 2, 201, "support"), // update + Row(3, 3, 300, "hr"))) // unchanged + } + } + + test("merge with CTE") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (0, 101, "support"), + (2, 301, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""WITH cte1 AS (SELECT pk + 1 as pk, salary, dep FROM source) + |MERGE INTO $tableNameAsString AS t + |USING cte1 AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), // unchanged + Row(2, 200, "software"))) // unchanged + } + } + + test("merge with subquery as source") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 6, "salary": 600, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (2, 201, "support"), + (1, 101, "support"), + (3, 301, "support"), + (6, 601, "support")) + sourceRows.toDF("pk", "salary", 
"dep").createOrReplaceTempView("source") + + val subquery = + s""" + |SELECT * FROM source WHERE pk = 2 + |UNION ALL + |SELECT * FROM source WHERE pk = 1 OR pk = 6 + |""".stripMargin + + sql( + s"""MERGE INTO $tableNameAsString AS t + |USING ($subquery) AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), // update + Row(2, 201, "support"), // insert + Row(6, 601, "support"))) // update + } + } + + test("merge cardinality check with conditional MATCHED clause (delete)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 6, "salary": 600, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 101, "support"), + (1, 102, "support"), + (2, 201, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + assertCardinalityError( + s"""MERGE INTO $tableNameAsString AS t + |USING source AS s + |ON t.pk = s.pk + |WHEN MATCHED AND s.salary = 101 THEN + | DELETE + |""".stripMargin) + } + } + + test("merge cardinality check with unconditional MATCHED clause (delete)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 6, "salary": 600, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 101, "support"), + (1, 102, "support"), + (2, 201, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString AS t + |USING source AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(6, 600, "software"))) // unchanged + } + } + + test("merge cardinality check with only NOT MATCHED clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 6, "salary": 600, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 101, "support"), + (1, 102, "support"), + (2, 201, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString AS t + |USING source AS s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), // unchanged + Row(2, 201, "support"), // insert + Row(6, 600, "software"))) // unchanged + } + } + + test("merge cardinality check with small target and large source (broadcast enabled)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = (1 to 1000).map(pk => (pk, pk * 100, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { + assertCardinalityError( + s"""MERGE INTO $tableNameAsString AS t + |USING (SELECT * FROM source UNION ALL SELECT * FROM source) AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assertNoLeftBroadcastOrReplication( + s"""MERGE INTO $tableNameAsString AS t + |USING source AS s + |ON t.pk = s.pk + |WHEN 
MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assert(sql(s"SELECT * FROM $tableNameAsString").count() == 2) + } + } + } + + test("merge cardinality check with small target and large source (broadcast disabled)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = (1 to 1000).map(pk => (pk, pk * 100, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + assertCardinalityError( + s"""MERGE INTO $tableNameAsString AS t + |USING (SELECT * FROM source UNION ALL SELECT * FROM source) AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assertNoLeftBroadcastOrReplication( + s"""MERGE INTO $tableNameAsString AS t + |USING source AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assert(sql(s"SELECT * FROM $tableNameAsString").count() == 2) + } + } + } + + test("merge cardinality check with small target and large source (shuffle hash enabled)") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = (1 to 1000).map(pk => (pk, pk * 100, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + + assertCardinalityError( + s"""MERGE INTO $tableNameAsString AS t + |USING (SELECT * FROM source UNION ALL SELECT * FROM source) AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assertNoLeftBroadcastOrReplication( + s"""MERGE INTO $tableNameAsString AS t + |USING source AS s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assert(sql(s"SELECT * FROM $tableNameAsString").count() == 2) + } + } + } + + test("merge cardinality check without equality condition and only MATCHED clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = (1 to 1000).map(pk => (pk, pk * 100, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + assertCardinalityError( + s"""MERGE INTO $tableNameAsString AS t + |USING (SELECT * FROM source UNION ALL SELECT * FROM source) AS s + |ON t.pk > s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |""".stripMargin) + + assert(sql(s"SELECT * FROM $tableNameAsString").count() == 2) + } + } + } + + test("merge cardinality check without equality condition") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = (1 to 1000).map(pk => (pk, pk * 100, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + assertCardinalityError( + s"""MERGE INTO $tableNameAsString AS t + |USING (SELECT * FROM source UNION ALL SELECT * FROM source) AS s 
+ |ON t.pk > s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + assert(sql(s"SELECT * FROM $tableNameAsString").count() == 2) + } + } + } + + test("self merge") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql( + s"""MERGE INTO $tableNameAsString t + |USING $tableNameAsString s + |ON t.pk = s.pk + |WHEN MATCHED AND t.salary = 100 THEN + | UPDATE SET salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + test("merge with self subquery") { + withTempView("ids") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + Seq(1, 2).toDF("value").createOrReplaceTempView("ids") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING (SELECT pk FROM $tableNameAsString r JOIN ids ON r.pk = ids.value) s + |ON t.pk = s.pk + |WHEN MATCHED AND t.salary = 100 THEN + | UPDATE SET salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (dep, salary, pk) VALUES ('new', 300, 1) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + } + + test("merge with extra columns in source") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceRows = Seq( + (1, "smth", 101, "support"), + (2, "smth", 201, "support"), + (4, "smth", 401, "support")) + sourceRows.toDF("pk", "extra", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, s.dep) + |WHEN NOT MATCHED BY SOURCE THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 102, "hr"), // update + Row(2, 202, "software"), // update + Row(4, 401, "support"))) // insert + } + } + + test("merge with NULL values in target and source") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING", + """{ "pk": 1, "id": null, "salary": 100, "dep": "hr" } + |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (5, None, 501, "support"), + (6, Some(6), 601, "support")) + sourceRows.toDF("pk", "id", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.id = s.id + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, null, 100, "hr"), // unchanged + Row(2, 2, 200, "software"), // unchanged + Row(5, null, 501, "support"), // insert + Row(6, 6, 601, "support"))) // insert + } + } + + 
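+  // The previous test relies on standard equality in the ON clause: a NULL id on either
+  // side never satisfies "t.id = s.id", so NULL-keyed source rows are treated as NOT
+  // MATCHED and inserted. The next test uses the null-safe operator <=> instead, which
+  // treats two NULLs as equal and therefore routes the NULL-keyed row through the
+  // MATCHED branch.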
test("merge with <=>") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING", + """{ "pk": 1, "id": null, "salary": 100, "dep": "hr" } + |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (5, None, 501, "support"), + (6, Some(6), 601, "support")) + sourceRows.toDF("pk", "id", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.id <=> s.id + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 2, 200, "software"), // unchanged + Row(5, null, 501, "support"), // updated + Row(6, 6, 601, "support"))) // insert + } + } + + test("merge with NULL ON condition") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING", + """{ "pk": 1, "id": null, "salary": 100, "dep": "hr" } + |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (5, None, 501, "support"), + (6, Some(2), 201, "support")) + sourceRows.toDF("pk", "id", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.id = s.id AND NULL + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, null, 100, "hr"), // unchanged + Row(2, 2, 200, "software"), // unchanged + Row(5, null, 501, "support"), // new + Row(6, 2, 201, "support"))) // new + } + } + + test("merge with NULL clause conditions") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 101, "support"), + (3, 301, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND NULL THEN + | UPDATE SET salary = s.salary + |WHEN NOT MATCHED AND NULL THEN + | INSERT * + |WHEN NOT MATCHED BY SOURCE AND NULL THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), // unchanged + Row(2, 200, "software"))) // unchanged + } + } + + test("merge with multiple matching clauses") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + (1, 101, "support"), + (3, 301, "support")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND t.pk = 1 THEN + | UPDATE SET salary = t.salary + 5 + |WHEN MATCHED AND t.salary = 100 THEN + | UPDATE SET salary = t.salary + 2 + |WHEN NOT MATCHED BY SOURCE AND t.pk = 2 THEN + | UPDATE SET salary = salary - 1 + |WHEN NOT MATCHED BY SOURCE AND t.salary = 200 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 105, "hr"), // updated (matched) + Row(2, 199, "software"))) // updated (not matched by source) + } + } + + test("merge resolves and aligns columns by name") { + 
withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceRows = Seq( + ("support", 1, 101), + ("support", 3, 301)) + sourceRows.toDF("dep", "pk", "salary").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), // update + Row(2, 200, "software"), // unchanged + Row(3, 301, "support"))) // insert + } + } + + test("merge refreshed relation cache") { + withTempView("temp", "source") { + withCache("temp") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 100, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // define a view on top of the table + val query = sql(s"SELECT * FROM $tableNameAsString WHERE salary = 100") + query.createOrReplaceTempView("temp") + + // cache the view + sql("CACHE TABLE temp") + + // verify the view returns expected results + checkAnswer( + sql("SELECT * FROM temp"), + Row(1, 100, "hr") :: Row(2, 100, "software") :: Nil) + + val sourceRows = Seq( + ("support", 1, 101), + ("support", 3, 301)) + sourceRows.toDF("dep", "pk", "salary").createOrReplaceTempView("source") + + // merge changes into the table + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + // verify the merge was successful + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), // update + Row(2, 100, "software"), // unchanged + Row(3, 301, "support"))) // insert + + // verify the view reflects the changes in the table + checkAnswer(sql("SELECT * FROM temp"), Row(2, 100, "software") :: Nil) + } + } + } + + test("merge with updates to nested struct fields in MATCHED clauses") { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING""".stripMargin, + """{ "pk": 1, "s": { "c1": 2, "c2": { "a": [1,2], "m": { "a": "b" } } }, "dep": "hr" }""") + + Seq(1, 3).toDF("pk").createOrReplaceTempView("source") + + // update primitive, array, map columns inside a struct + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1) + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"))), "hr"))) + + // set primitive, array, map columns to NULL (proper casts should be in inserted) + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET s.c1 = NULL, s.c2 = NULL + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, Row(null, null), "hr") :: Nil) + + // assign an entire struct + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null)) + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, Row(1, Row(Seq(1), null)), "hr") :: Nil) + } + } + + 
test("merge with updates to nested struct fields in NOT MATCHED BY SOURCE clauses") { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING""".stripMargin, + """{ "pk": 1, "s": { "c1": 2, "c2": { "a": [1,2], "m": { "a": "b" } } }, "dep": "hr" }""") + + Seq(2, 4).toDF("pk").createOrReplaceTempView("source") + + // update primitive, array, map columns inside a struct + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1) + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"))), "hr"))) + + // set primitive, array, map columns to NULL (proper casts should be in inserted) + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET s.c1 = NULL, s.c2 = NULL + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, Row(null, null), "hr") :: Nil) + + // assign an entire struct + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null)) + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, Row(1, Row(Seq(1), null)), "hr") :: Nil) + } + } + + test("merge with default values") { + withTempView("source") { + val idDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType)) + val columns = Array( + Column.create("pk", IntegerType, false, null, null), + Column.create("id", IntegerType, true, null, idDefault, null), + Column.create("dep", StringType, true, null, null)) + + createTable(columns) + + append("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + Seq(1, 2, 4).toDF("pk").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET id = DEFAULT + |WHEN NOT MATCHED THEN + | INSERT (pk, id, dep) VALUES (s.pk, DEFAULT, 'new') + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET id = DEFAULT + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 42, "hr"), // update (matched) + Row(2, 42, "software"), // update (matched) + Row(3, 42, "hr"), // update (not matched by source) + Row(4, 42, "new"))) // insert + } + } + + test("merge with char/varchar columns") { + withTempView("source") { + createTable("pk INT NOT NULL, s STRUCT, dep STRING") + + append("pk INT NOT NULL, s STRUCT, dep STRING", + """{ "pk": 1, "s": { "n_c": "aaa", "n_vc": "aaa" }, "dep": "hr" } + |{ "pk": 2, "s": { "n_c": "bbb", "n_vc": "bbb" }, "dep": "software" } + |{ "pk": 3, "s": { "n_c": "ccc", "n_vc": "ccc" }, "dep": "hr" } + |""".stripMargin) + + Seq(1, 2, 4).toDF("pk").createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET s.n_c = 'x1', s.n_vc = 'x2' + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET s.n_c = 'y1', s.n_vc = 'y2' + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row("x1 ", "x2"), "hr"), // update (matched) + Row(2, Row("x1 ", "x2"), "software"), // update (matched) + 
Row(3, Row("y1 ", "y2"), "hr"))) // update (not matched by source) + } + } + + test("merge with NOT NULL checks") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, s STRUCT, dep STRING", + """{ "pk": 1, "s": { "n_i": 1, "n_l": 11 }, "dep": "hr" } + |{ "pk": 2, "s": { "n_i": 2, "n_l": 22 }, "dep": "software" } + |{ "pk": 3, "s": { "n_i": 3, "n_l": 33 }, "dep": "hr" } + |""".stripMargin) + + Seq(1, 4).toDF("pk").createOrReplaceTempView("source") + + val e1 = intercept[SparkException] { + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) + |""".stripMargin) + } + assert(e1.getCause.getMessage.contains("Null value appeared in non-nullable field")) + + val e2 = intercept[SparkException] { + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) + |""".stripMargin) + } + assert(e2.getCause.getMessage.contains("Null value appeared in non-nullable field")) + + val e3 = intercept[SparkException] { + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, s, dep) VALUES (s.pk, named_struct('n_i', null, 'n_l', -1L), 'invalid') + |""".stripMargin) + } + assert(e3.getCause.getMessage.contains("Null value appeared in non-nullable field")) + } + } + + private def assertNoLeftBroadcastOrReplication(query: String): Unit = { + val plan = executeAndKeepPlan { + sql(query) + } + assertNoLeftBroadcastOrReplication(plan) + } + + private def assertNoLeftBroadcastOrReplication(plan: SparkPlan): Unit = { + val joins = plan.collect { + case j: BroadcastHashJoinExec if j.buildSide == BuildLeft => j + case j: BroadcastNestedLoopJoinExec if j.buildSide == BuildLeft => j + case j: CartesianProductExec => j + } + assert(joins.isEmpty, "left side must not be broadcasted or replicated") + } + + private def assertCardinalityError(query: String): Unit = { + val e = intercept[SparkException] { + sql(query) + } + assert(e.getCause.getMessage.contains("ON search condition of the MERGE statement")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index e330dcc0aaccd..0cb94709898ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -25,10 +25,12 @@ import org.apache.spark.sql.{DataFrame, Encoders, QueryTest} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.QueryExecutionListener abstract class RowLevelOperationSuiteBase extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { @@ -92,4 +94,24 @@ abstract class RowLevelOperationSuiteBase 
spark.read.schema(schemaString).json(jsonDS) } } + + // executes an operation and keeps the executed plan + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { + var executedPlan: SparkPlan = null + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + executedPlan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + } + } + spark.listenerManager.register(listener) + + func + + sparkContext.listenerBus.waitUntilEmpty() + + stripAQEPlan(executedPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index 97a531da4f619..66a986da936ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -196,11 +196,16 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { } class CustomAnalyzer(catalogManager: CatalogManager) extends Analyzer(catalogManager) { + + private val ignoredRuleNames = Set( + "org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable", + "org.apache.spark.sql.catalyst.analysis.RewriteMergeIntoTable") + override def batches: Seq[Batch] = { val defaultBatches = super.batches defaultBatches.map { batch => val filteredRules = batch.rules.filterNot { rule => - rule.ruleName == "org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable" + ignoredRuleNames.contains(rule.ruleName) } Batch(batch.name, batch.strategy, filteredRules: _*) }