[SPARK-48589][SQL][SS] Add option snapshotStartBatchId and snapshotPartitionId to state data source

### What changes were proposed in this pull request?

This PR defines two new options, snapshotStartBatchId and snapshotPartitionId, for the existing state reader. Both options must be provided at the same time (see the usage sketch after this list).
1. If there is no snapshot file at `snapshotStartBatchId` (note the off-by-one offset between store version and batch id), throw an exception.
2. Otherwise, the reader rebuilds the state by reading delta files only, ignoring all later snapshot files.
3. If a `batchId` option is also specified, that batchId is treated as the ending batch id, and state reconstruction stops there.
4. This feature supports state generated by the HDFS state store provider and by the RocksDB state store provider with changelog checkpointing enabled. **It does not support RocksDB with changelog checkpointing disabled, which is the default for RocksDB.**
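A minimal usage sketch of the new options (illustrative only; the checkpoint path, operator id, and batch ids below are placeholders, while `statestore` is the existing short name of the state data source):

```scala
// Rebuild partition 3 of operator 0 starting from the snapshot written for
// batch 10, then apply delta files up to batch 25 (the ending batchId).
val stateDf = spark.read
  .format("statestore")
  .option("path", "/tmp/checkpoint")   // streaming query checkpoint location
  .option("operatorId", 0)
  .option("batchId", 25)               // optional: ending batch id
  .option("snapshotStartBatchId", 10)  // snapshot to start replaying from
  .option("snapshotPartitionId", 3)    // partition to read
  .load()

stateDf.show()
```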

### Why are the changes needed?

Sometimes when a snapshot is corrupted, users want to bypass it when reading a later state. This PR gives users the ability to specify the starting snapshot version and partition. This feature can be useful for debugging purposes.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added test cases covering edge cases in the input of the new options. Added a test for the new public function `replayReadStateFromSnapshot`. Added an integration test for the new options against four stateful operators: limit, aggregation, deduplication, and stream-stream join. Instead of generating state within the tests, which is unstable, the integration test uses prepared golden files.
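For reference, a hedged sketch of how the new `replayReadStateFromSnapshot` API might be exercised directly on a provider that supports fine-grained replay (checkpoint setup is elided and the version numbers are placeholders; batch id N corresponds to store version N + 1):

```scala
import org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider

// Assumes the provider has already been initialized (StateStoreProvider.init)
// against an existing checkpoint with the proper schemas and configuration.
val provider: HDFSBackedStateStoreProvider = ???

// Read-only store for batch 25, rebuilt from the snapshot written for batch 10
// (store versions are batch ids + 1, hence the 11 and 26).
val readStore = provider.replayReadStateFromSnapshot(snapshotVersion = 11L, endVersion = 26L)

readStore.iterator().foreach { pair =>
  // pair.key and pair.value are UnsafeRows in the operator's state schema
}
```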

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46944 from eason-yuchen-liu/skipSnapshotAtBatch.

Lead-authored-by: Yuchen Liu <yuchen.liu@databricks.com>
Co-authored-by: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
2 people authored and HeartSaVioR committed Jul 2, 2024
1 parent db9e1ac commit ee0d306
Showing 894 changed files with 1,046 additions and 24 deletions.
17 changes: 17 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -236,6 +236,11 @@
"Error reading delta file <fileToRead> of <clazz>: <fileToRead> does not exist."
]
},
"CANNOT_READ_MISSING_SNAPSHOT_FILE" : {
"message" : [
"Error reading snapshot file <fileToRead> of <clazz>: <fileToRead> does not exist."
]
},
"CANNOT_READ_SNAPSHOT_FILE_KEY_SIZE" : {
"message" : [
"Error reading snapshot file <fileToRead> of <clazz>: key size cannot be <keySize>."
@@ -251,6 +256,11 @@
"Error reading streaming state file of <fileToRead> does not exist. If the stream job is restarted with a new or updated state operation, please create a new checkpoint location or clear the existing checkpoint location."
]
},
"SNAPSHOT_PARTITION_ID_NOT_FOUND" : {
"message" : [
"Partition id <snapshotPartitionId> not found for state of operator <operatorId> at <checkpointLocation>."
]
},
"UNCATEGORIZED" : {
"message" : [
""
@@ -3799,6 +3809,13 @@
],
"sqlState" : "42802"
},
"STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : {
"message" : [
"The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.",
"Therefore, it does not support option snapshotStartBatchId in state data source."
],
"sqlState" : "42K06"
},
"STATE_STORE_UNSUPPORTED_OPERATION" : {
"message" : [
"<operationType> operation not supported with <entity>"
@@ -116,12 +116,16 @@ case class StateSourceOptions(
batchId: Long,
operatorId: Int,
storeName: String,
joinSide: JoinSideValues) {
joinSide: JoinSideValues,
snapshotStartBatchId: Option[Long],
snapshotPartitionId: Option[Int]) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide)"
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
s"snapshotStartBatchId=${snapshotStartBatchId.getOrElse("None")}, " +
s"snapshotPartitionId=${snapshotPartitionId.getOrElse("None")})"
}
}

@@ -131,6 +135,8 @@ object StateSourceOptions extends DataSourceOptions {
val OPERATOR_ID = newOption("operatorId")
val STORE_NAME = newOption("storeName")
val JOIN_SIDE = newOption("joinSide")
val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -190,7 +196,30 @@ object StateSourceOptions extends DataSourceOptions {
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
}

StateSourceOptions(resolvedCpLocation, batchId, operatorId, storeName, joinSide)
val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
if (snapshotStartBatchId.exists(_ < 0)) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
} else if (snapshotStartBatchId.exists(_ > batchId)) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
}

val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt)
if (snapshotPartitionId.exists(_ < 0)) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
}

// both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
// each partition may have different checkpoint status
if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID)
} else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
}

StateSourceOptions(
resolvedCpLocation, batchId, operatorId, storeName,
joinSide, snapshotStartBatchId, snapshotPartitionId)
}

private def resolvedCheckpointLocation(
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

@@ -93,7 +93,19 @@ class StatePartitionReader(
}

private lazy val store: ReadStateStore = {
provider.getReadStore(partition.sourceOptions.batchId + 1)
partition.sourceOptions.snapshotStartBatchId match {
case None => provider.getReadStore(partition.sourceOptions.batchId + 1)

case Some(snapshotStartBatchId) =>
if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
provider.getClass.toString)
}
provider.asInstanceOf[SupportsFineGrainedReplay]
.replayReadStateFromSnapshot(
snapshotStartBatchId + 1,
partition.sourceOptions.batchId + 1)
}
}

private lazy val iter: Iterator[InternalRow] = {
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

@@ -81,9 +81,20 @@ class StateScan(
assert((tail - head + 1) == partitionNums.length,
s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}")

partitionNums.map {
pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
}.toArray
sourceOptions.snapshotPartitionId match {
case None => partitionNums.map { pn =>
new StateStoreInputPartition(pn, queryId, sourceOptions)
}.toArray

case Some(snapshotPartitionId) =>
if (partitionNums.contains(snapshotPartitionId)) {
Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
} else {
throw StateStoreErrors.stateStoreSnapshotPartitionNotFound(
snapshotPartitionId, sourceOptions.operatorId,
sourceOptions.stateCheckpointLocation.toString)
}
}
}
}

@@ -49,16 +49,21 @@ class StateTable(
}

override def name(): String = {
val desc = s"StateTable " +
var desc = s"StateTable " +
s"[stateCkptLocation=${sourceOptions.stateCheckpointLocation}]" +
s"[batchId=${sourceOptions.batchId}][operatorId=${sourceOptions.operatorId}]" +
s"[storeName=${sourceOptions.storeName}]"

if (sourceOptions.joinSide != JoinSideValues.none) {
desc + s"[joinSide=${sourceOptions.joinSide}]"
} else {
desc
desc += s"[joinSide=${sourceOptions.joinSide}]"
}
if (sourceOptions.snapshotStartBatchId.isDefined) {
desc += s"[snapshotStartBatchId=${sourceOptions.snapshotStartBatchId}]"
}
if (sourceOptions.snapshotPartitionId.isDefined) {
desc += s"[snapshotPartitionId=${sourceOptions.snapshotPartitionId}]"
}
desc
}

override def capabilities(): util.Set[TableCapability] = CAPABILITY
@@ -116,7 +116,8 @@ class StreamStreamJoinStatePartitionReader(
partitionId = partition.partition,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false
useStateStoreCoordinator = false,
snapshotStartVersion = partition.sourceOptions.snapshotStartBatchId.map(_ + 1)
)
}

@@ -71,7 +71,8 @@ import org.apache.spark.util.ArrayImplicits._
* to ensure re-executed RDD operations re-apply updates on the correct past version of the
* store.
*/
private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
with SupportsFineGrainedReplay {

private val providerName = "HDFSBackedStateStoreProvider"

@@ -683,6 +684,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
}

/**
* Try to read the snapshot file. If the snapshot file is not available, return [[None]].
*
* @param version the version of the snapshot file
*/
private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
val fileToRead = snapshotFile(version)
val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
Expand Down Expand Up @@ -883,4 +889,93 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
throw new IllegalStateException(msg)
}
}

/**
* Get the state store of endVersion by applying delta files on the snapshot of snapshotVersion.
* If snapshot for snapshotVersion does not exist, an error will be thrown.
*
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @return [[HDFSBackedStateStore]]
*/
override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
new HDFSBackedStateStore(endVersion, newMap)
}

/**
* Get the state store of endVersion for reading by applying delta files on the snapshot of
* snapshotVersion. If snapshot for snapshotVersion does not exist, an error will be thrown.
*
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @return [[HDFSBackedReadStateStore]]
*/
override def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long):
ReadStateStore = {
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for read-only")
new HDFSBackedReadStateStore(endVersion, newMap)
}

/**
* Construct the state map at endVersion from snapshot of version snapshotVersion.
* Returns a new [[HDFSBackedStateStoreMap]]
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
private def replayLoadedMapFromSnapshot(snapshotVersion: Long, endVersion: Long):
HDFSBackedStateStoreMap = synchronized {
try {
if (snapshotVersion < 1) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion)
}
if (endVersion < snapshotVersion) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}

val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
newMap.putAll(constructMapFromSnapshot(snapshotVersion, endVersion))

newMap
}
catch {
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

private def constructMapFromSnapshot(snapshotVersion: Long, endVersion: Long):
HDFSBackedStateStoreMap = {
val (result, elapsedMs) = Utils.timeTakenMs {
val startVersionMap = synchronized { Option(loadedMaps.get(snapshotVersion)) } match {
case Some(value) => Option(value)
case None => readSnapshotFile(snapshotVersion)
}
if (startVersionMap.isEmpty) {
throw StateStoreErrors.stateStoreSnapshotFileNotFound(
snapshotFile(snapshotVersion).toString, toString())
}

// Load all the deltas from the version after the start version up to the end version.
val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
resultMap.putAll(startVersionMap.get)
for (deltaVersion <- snapshotVersion + 1 to endVersion) {
updateFromDeltaFile(deltaVersion, resultMap)
}

resultMap
}

logDebug(s"Loading snapshot at version $snapshotVersion and apply delta files to version " +
s"$endVersion takes $elapsedMs ms.")

result
}
}
@@ -233,6 +233,80 @@ class RocksDB(
this
}

/**
* Load from the start snapshot version and apply all the changelog records to reach the
* end version. Note that this will copy all the necessary files from DFS to local disk as needed,
* and possibly restart the native RocksDB instance.
*
* @param snapshotVersion version of the snapshot to start with
* @param endVersion end version
* @return A RocksDB instance loaded with the state endVersion replayed from snapshotVersion.
* Note that the instance will be read-only since this method is only used in State Data
* Source.
*/
def loadFromSnapshot(snapshotVersion: Long, endVersion: Long): RocksDB = {
assert(snapshotVersion >= 0 && endVersion >= snapshotVersion)
acquire(LoadStore)
recordedMetrics = None
logInfo(
log"Loading snapshot at version ${MDC(LogKeys.VERSION_NUM, snapshotVersion)} and apply " +
log"changelog files to version ${MDC(LogKeys.VERSION_NUM, endVersion)}.")
try {
replayFromCheckpoint(snapshotVersion, endVersion)

logInfo(
log"Loaded snapshot at version ${MDC(LogKeys.VERSION_NUM, snapshotVersion)} and apply " +
log"changelog files to version ${MDC(LogKeys.VERSION_NUM, endVersion)}.")
} catch {
case t: Throwable =>
loadedVersion = -1 // invalidate loaded data
throw t
}
this
}

/**
* Load from the start checkpoint version and apply all the changelog records to reach the
* end version.
* If the start version does not exist, it will throw an exception.
*
* @param snapshotVersion start checkpoint version
* @param endVersion end version
*/
private def replayFromCheckpoint(snapshotVersion: Long, endVersion: Long): Any = {
closeDB()
val metadata = fileManager.loadCheckpointFromDfs(snapshotVersion, workingDir)
loadedVersion = snapshotVersion

// reset last snapshot version
if (lastSnapshotVersion > snapshotVersion) {
// discard any newer snapshots
lastSnapshotVersion = 0L
latestSnapshot = None
}
openDB()

numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) {
// we don't track the total number of rows - discard the number being tracked
-1L
} else if (metadata.numKeys < 0) {
// we track the total number of rows, but the snapshot doesn't have tracking number
// need to count keys now
countKeys()
} else {
metadata.numKeys
}
if (loadedVersion != endVersion) replayChangelog(endVersion)
// After changelog replay the numKeysOnWritingVersion will be updated to
// the correct number of keys in the loaded version.
numKeysOnLoadedVersion = numKeysOnWritingVersion
fileManagerMetrics = fileManager.latestLoadCheckpointMetrics

if (conf.resetStatsOnLoad) {
nativeStats.reset
}
}

/**
* Replay change log from the loaded version to the target version.
*/