Support idempotent write using SQL options for INSERT/DELETE/UPDATE/MERGE

Currently, Delta supports idempotent writes through DataFrame writer options, which apply to inserts only. This PR adds support for idempotency through SQL session configurations (DELTA_IDEMPOTENT_DML_TXN_APP_ID and DELTA_IDEMPOTENT_DML_TXN_VERSION) for INSERT/DELETE/UPDATE/MERGE. When both the writer options and the SQL configurations are specified, the writer option values take precedence.

Idempotent writes work by checking the txnVersion and txnAppId taken from the user-provided write options or from the session configurations (as SQL confs). If the same or a higher txnVersion has already been recorded for that txnAppId, the write is skipped.

Adds unit tests covering the idempotency behavior.
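
For illustration, a minimal usage sketch (not part of this commit): the txnAppId/txnVersion writer options and the spark.databricks.delta.write.* conf keys are the ones named in the diff below, while the table path, app ID, and version numbers are hypothetical.

// Sketch only: hypothetical table path, app ID, and version numbers.
val df = spark.range(100).toDF("id")

// Path 1 (pre-existing): idempotency via DataFrame writer options (appends only).
df.write.format("delta")
  .option("txnAppId", "etl-job")
  .option("txnVersion", 1L)
  .mode("append")
  .save("/tmp/target")

// Path 2 (this PR): idempotency via SQL session confs, covering INSERT/DELETE/UPDATE/MERGE.
spark.conf.set("spark.databricks.delta.write.txnAppId", "etl-job")
spark.conf.set("spark.databricks.delta.write.txnVersion", "2")
spark.sql("DELETE FROM delta.`/tmp/target` WHERE id < 100")
// A retry of the same batch is skipped: version 2 is already recorded for "etl-job"
// in the table's transaction log.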
zedtang authored and scottsand-db committed Jan 11, 2023
1 parent 5fe54d0 commit 57d68b3
Showing 8 changed files with 696 additions and 69 deletions.
CreateDeltaTableCommand.scala

@@ -55,6 +55,7 @@ case class CreateDeltaTableCommand(
     tableByPath: Boolean = false,
     override val output: Seq[Attribute] = Nil)
   extends LeafRunnableCommand
+  with DeltaCommand
   with DeltaLogging {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -125,15 +126,17 @@
         // we are creating a table as part of a RunnableCommand
         query.get match {
           case writer: WriteIntoDelta =>
-            // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
-            // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
-            if (!isV1Writer) {
-              replaceMetadataIfNecessary(
-                txn, tableWithLocation, options, writer.data.schema.asNullable)
+            if (!hasBeenExecuted(txn, sparkSession, Some(options))) {
+              // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
+              // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
+              if (!isV1Writer) {
+                replaceMetadataIfNecessary(
+                  txn, tableWithLocation, options, writer.data.schema.asNullable)
+              }
+              val actions = writer.write(txn, sparkSession)
+              val op = getOperation(txn.metadata, isManagedTable, Some(options))
+              txn.commit(actions, op)
             }
-            val actions = writer.write(txn, sparkSession)
-            val op = getOperation(txn.metadata, isManagedTable, Some(options))
-            txn.commit(actions, op)
           case cmd: RunnableCommand =>
             result = cmd.run(sparkSession)
           case other =>
@@ -156,8 +159,10 @@
           configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull),
           data = data).write(txn, sparkSession)
 
-        val op = getOperation(txn.metadata, isManagedTable, Some(options))
-        txn.commit(actions, op)
+        if (!hasBeenExecuted(txn, sparkSession, Some(options))) {
+          val op = getOperation(txn.metadata, isManagedTable, Some(options))
+          txn.commit(actions, op)
+        }
       }
     } else {
       def createTransactionLogOrVerify(): Unit = {
DeleteCommand.scala

@@ -29,9 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, Expression, If, Literal, Not}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{DeltaDelete, LogicalPlan}
-import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric
 import org.apache.spark.sql.functions.{input_file_name, lit, typedLit, udf}
 
@@ -85,6 +84,11 @@ case class DeleteCommand(
     recordDeltaOperation(deltaLog, "delta.dml.delete") {
       deltaLog.assertRemovable()
       deltaLog.withNewTransaction { txn =>
+        if (hasBeenExecuted(txn, sparkSession)) {
+          sendDriverMetrics(sparkSession, metrics)
+          return Seq.empty
+        }
+
         val deleteActions = performDelete(sparkSession, deltaLog, txn)
         if (deleteActions.nonEmpty) {
           txn.commit(deleteActions, DeltaOperations.Delete(condition.map(_.sql).toSeq))
@@ -277,10 +281,7 @@ case class DeleteCommand(
     numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set)
     numCopiedRows.foreach(metrics("numCopiedRows").set)
     txn.registerSQLMetrics(sparkSession, metrics)
-    // This is needed to make the SQL metrics visible in the Spark UI
-    val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(
-      sparkSession.sparkContext, executionId, metrics.values.toSeq)
+    sendDriverMetrics(sparkSession, metrics)
 
     recordDeltaEvent(
       deltaLog,
@@ -309,7 +310,11 @@
         rewriteTimeMs)
     )
 
-    deleteActions
+    if (deleteActions.nonEmpty) {
+      createSetTransaction(sparkSession, deltaLog).toSeq ++ deleteActions
+    } else {
+      Seq.empty
+    }
   }
 
   /**
DeltaCommand.scala

@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.sql.delta.{CommitStats, DeltaErrors, DeltaLog, DeltaOperations, DeltaTableIdentifier, OptimisticTransaction, Serializable, Snapshot}
+import org.apache.spark.sql.delta.{CommitStats, DeltaErrors, DeltaLog, DeltaOperations, DeltaOptions, DeltaTableIdentifier, OptimisticTransaction, Serializable, Snapshot}
 import org.apache.spark.sql.delta.actions._
 import org.apache.spark.sql.delta.files.TahoeBatchFileIndex
 import org.apache.spark.sql.delta.metering.DeltaLogging
@@ -39,7 +39,9 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType
 import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -423,4 +425,107 @@ trait DeltaCommand extends DeltaLogging {
     deltaLog
   }
 
+  /**
+   * Send the driver-side metrics.
+   *
+   * This is needed to make the SQL metrics visible in the Spark UI.
+   * All metrics are default initialized with 0 so that's what we're
+   * reporting in case we skip an already executed action.
+   */
+  protected def sendDriverMetrics(spark: SparkSession, metrics: Map[String, SQLMetric]): Unit = {
+    val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq)
+  }
+
+  /**
+   * Returns true if there is information in the spark session that indicates that this write
+   * has already been successfully written.
+   */
+  protected def hasBeenExecuted(txn: OptimisticTransaction, sparkSession: SparkSession,
+      options: Option[DeltaOptions] = None): Boolean = {
+    val (txnVersionOpt, txnAppIdOpt, isFromSessionConf) = getTxnVersionAndAppId(
+      sparkSession, options)
+    // only enter if both txnVersion and txnAppId are set
+    for (version <- txnVersionOpt; appId <- txnAppIdOpt) {
+      val currentVersion = txn.txnVersion(appId)
+      if (currentVersion >= version) {
+        logInfo(s"Already completed batch $version in application $appId. This will be skipped.")
+        if (isFromSessionConf && sparkSession.sessionState.conf.getConf(
+            DeltaSQLConf.DELTA_IDEMPOTENT_DML_AUTO_RESET_ENABLED)) {
+          // if we got txnAppId and txnVersion from the session config, we reset the
+          // version here, after skipping the current transaction, as a safety measure to
+          // prevent data loss if the user forgets to manually reset txnVersion
+          sparkSession.sessionState.conf.unsetConf(DeltaSQLConf.DELTA_IDEMPOTENT_DML_TXN_VERSION)
+        }
+        return true
+      }
+    }
+    false
+  }
+
+  /**
+   * Returns SetTransaction if a valid app ID and version are present. Otherwise returns
+   * an empty list.
+   */
+  protected def createSetTransaction(
+      sparkSession: SparkSession,
+      deltaLog: DeltaLog,
+      options: Option[DeltaOptions] = None): Option[SetTransaction] = {
+    val (txnVersionOpt, txnAppIdOpt, isFromSessionConf) = getTxnVersionAndAppId(
+      sparkSession, options)
+    // only enter if both txnVersion and txnAppId are set
+    for (version <- txnVersionOpt; appId <- txnAppIdOpt) {
+      if (isFromSessionConf && sparkSession.sessionState.conf.getConf(
+          DeltaSQLConf.DELTA_IDEMPOTENT_DML_AUTO_RESET_ENABLED)) {
+        // if we got txnAppID and txnVersion from the session config, we reset the
+        // version here as a safety measure to prevent data loss if the user forgets
+        // to manually reset txnVersion
+        sparkSession.sessionState.conf.unsetConf(DeltaSQLConf.DELTA_IDEMPOTENT_DML_TXN_VERSION)
+      }
+      return Some(SetTransaction(appId, version, Some(deltaLog.clock.getTimeMillis())))
+    }
+    None
+  }
+
+  /**
+   * Helper method to retrieve the current txn version and app ID. These are either
+   * retrieved from user-provided write options or from session configurations.
+   */
+  private def getTxnVersionAndAppId(
+      sparkSession: SparkSession,
+      options: Option[DeltaOptions]): (Option[Long], Option[String], Boolean) = {
+    var txnVersion: Option[Long] = None
+    var txnAppId: Option[String] = None
+    for (o <- options) {
+      txnVersion = o.txnVersion
+      txnAppId = o.txnAppId
+    }
+
+    var numOptions = txnVersion.size + txnAppId.size
+    // numOptions can only be 0 or 2, as enforced by
+    // DeltaWriteOptionsImpl.validateIdempotentWriteOptions so this
+    // assert should never be triggered
+    assert(numOptions == 0 || numOptions == 2, s"Only one of txnVersion and txnAppId " +
+      s"has been set via dataframe writer options: txnVersion = $txnVersion txnAppId = $txnAppId")
+    var fromSessionConf = false
+    if (numOptions == 0) {
+      txnVersion = sparkSession.sessionState.conf.getConf(
+        DeltaSQLConf.DELTA_IDEMPOTENT_DML_TXN_VERSION)
+      // don't need to check for valid conversion to Long here as that
+      // is already enforced at set time
+      txnAppId = sparkSession.sessionState.conf.getConf(
+        DeltaSQLConf.DELTA_IDEMPOTENT_DML_TXN_APP_ID)
+      // check that both session configs are set
+      numOptions = txnVersion.size + txnAppId.size
+      if (numOptions != 0 && numOptions != 2) {
+        throw DeltaErrors.invalidIdempotentWritesOptionsException(
+          "Both spark.databricks.delta.write.txnAppId and " +
+          "spark.databricks.delta.write.txnVersion must be specified for " +
+          "idempotent Delta writes")
+      }
+      fromSessionConf = true
+    }
+    (txnVersion, txnAppId, fromSessionConf)
+  }
+
 }
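
A note on the helpers above (illustrative, not part of the diff): when the IDs come from the session confs and DELTA_IDEMPOTENT_DML_AUTO_RESET_ENABLED is on, both createSetTransaction (on commit) and hasBeenExecuted (on skip) unset the txnVersion conf, so each batch must set it again before writing. A sketch of that lifecycle, with a hypothetical table path, app ID, and version numbers:

// Hypothetical conf lifecycle with auto-reset enabled.
spark.conf.set("spark.databricks.delta.write.txnAppId", "nightly-update")

spark.conf.set("spark.databricks.delta.write.txnVersion", "10")
spark.sql("UPDATE delta.`/tmp/t` SET v = v + 1")  // commits; records version 10; unsets txnVersion

spark.conf.set("spark.databricks.delta.write.txnVersion", "10")
spark.sql("UPDATE delta.`/tmp/t` SET v = v + 1")  // retry of batch 10: already recorded, skipped

spark.conf.set("spark.databricks.delta.write.txnVersion", "11")
spark.sql("UPDATE delta.`/tmp/t` SET v = v + 1")  // next batch commits as version 11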
MergeIntoCommand.scala

@@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeRef
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -297,6 +296,10 @@ case class MergeIntoCommand(
     recordDeltaOperation(targetDeltaLog, "delta.dml.merge") {
       val startTime = System.nanoTime()
       targetDeltaLog.withNewTransaction { deltaTxn =>
+        if (hasBeenExecuted(deltaTxn, spark)) {
+          sendDriverMetrics(spark, metrics)
+          return Seq.empty
+        }
         if (target.schema.size != deltaTxn.metadata.schema.size) {
           throw DeltaErrors.schemaChangedSinceAnalysis(
             atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema)
@@ -321,6 +324,7 @@
           }
         }
 
+        val finalActions = createSetTransaction(spark, targetDeltaLog).toSeq ++ deltaActions
         // Metrics should be recorded before commit (where they are written to delta logs).
         metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000)
         deltaTxn.registerSQLMetrics(spark, metrics)
@@ -336,7 +340,7 @@
         }
 
         deltaTxn.commit(
-          deltaActions,
+          finalActions,
           DeltaOperations.Merge(
             Option(condition.sql),
             matchedClauses.map(DeltaOperations.MergePredicate(_)),
@@ -351,10 +355,7 @@
       }
       spark.sharedState.cacheManager.recacheByPlan(spark, target)
     }
-    // This is needed to make the SQL metrics visible in the Spark UI. Also this needs
-    // to be outside the recordMergeOperation because this method will update some metric.
-    val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq)
+    sendDriverMetrics(spark, metrics)
     Seq.empty
   }
 
UpdateCommand.scala

@@ -29,9 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, If, Literal}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric
 import org.apache.spark.sql.functions.{array, col, explode, input_file_name, lit, struct, typedLit, udf}
 
@@ -73,6 +72,10 @@ case class UpdateCommand(
     val deltaLog = tahoeFileIndex.deltaLog
     deltaLog.assertRemovable()
     deltaLog.withNewTransaction { txn =>
+      if (hasBeenExecuted(txn, sparkSession)) {
+        sendDriverMetrics(sparkSession, metrics)
+        return Seq.empty[Row]
+      }
      performUpdate(sparkSession, deltaLog, txn)
     }
     // Re-cache all cached plans(including this relation itself, if it's cached) that refer to
@@ -200,11 +203,9 @@
         metrics("numTouchedRows").value - metrics("numUpdatedRows").value)
     }
     txn.registerSQLMetrics(sparkSession, metrics)
-    txn.commit(totalActions, DeltaOperations.Update(condition.map(_.toString)))
-    // This is needed to make the SQL metrics visible in the Spark UI
-    val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(
-      sparkSession.sparkContext, executionId, metrics.values.toSeq)
+    val finalActions = createSetTransaction(sparkSession, deltaLog).toSeq ++ totalActions
+    txn.commit(finalActions, DeltaOperations.Update(condition.map(_.toString)))
+    sendDriverMetrics(sparkSession, metrics)
   }
 
   recordDeltaEvent(
WriteIntoDelta.scala

@@ -88,9 +88,7 @@ case class WriteIntoDelta(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     deltaLog.withNewTransaction { txn =>
-      // If this batch has already been executed within this query, then return.
-      var skipExecution = hasBeenExecuted(txn)
-      if (skipExecution) {
+      if (hasBeenExecuted(txn, sparkSession, Some(options))) {
        return Seq.empty
       }
 
@@ -277,8 +275,7 @@
     } else {
       newFiles ++ deletedFiles
     }
-    var setTxns = createSetTransaction()
-    setTxns.toSeq ++ fileActions
+    createSetTransaction(sparkSession, deltaLog, Some(options)).toSeq ++ fileActions
   }
 
   private def extractConstraints(
@@ -312,36 +309,4 @@ case class WriteIntoDelta(
     spark.sessionState.analyzer.checkAnalysis(command)
     command.asInstanceOf[DeleteCommand].performDelete(spark, txn.deltaLog, txn)
   }
-
-  /**
-   * Returns true if there is information in the spark session that indicates that this write, which
-   * is part of a streaming query and a batch, has already been successfully written.
-   */
-  private def hasBeenExecuted(txn: OptimisticTransaction): Boolean = {
-    val txnVersion = options.txnVersion
-    val txnAppId = options.txnAppId
-    for (v <- txnVersion; a <- txnAppId) {
-      val currentVersion = txn.txnVersion(a)
-      if (currentVersion >= v) {
-        logInfo(s"Transaction write of version $v for application id $a " +
-          s"has already been committed in Delta table id ${txn.deltaLog.tableId}. " +
-          s"Skipping this write.")
-        return true
-      }
-    }
-    false
-  }
-
-  /**
-   * Returns SetTransaction if a valid app ID and version are present. Otherwise returns
-   * an empty list.
-   */
-  private def createSetTransaction(): Option[SetTransaction] = {
-    val txnVersion = options.txnVersion
-    val txnAppId = options.txnAppId
-    for (v <- txnVersion; a <- txnAppId) {
-      return Some(SetTransaction(a, v, Some(deltaLog.clock.getTimeMillis())))
-    }
-    None
-  }
 }
(Diffs for the remaining 2 changed files, including the new unit tests, are not shown.)