Skip to content

Commit

Permalink
some naming and formatting comments from Anish and Jungtaek
Browse files Browse the repository at this point in the history
  • Loading branch information
eason-yuchen-liu committed Jun 26, 2024
1 parent 1a23abb commit 97ee3ef
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class StatePartitionReader(
case None => provider.getReadStore(partition.sourceOptions.batchId + 1)

case Some(snapshotStartBatchId) =>
provider.asInstanceOf[FineGrainedStateSource].replayReadStoreFromSnapshot(
provider.asInstanceOf[SupportsFineGrainedReplayFromSnapshot].replayReadStateFromSnapshot(
snapshotStartBatchId + 1,
partition.sourceOptions.batchId + 1)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ import org.apache.spark.util.ArrayImplicits._
* store.
*/
private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
with FineGrainedStateSource {
with SupportsFineGrainedReplayFromSnapshot {

private val providerName = "HDFSBackedStateStoreProvider"

Expand Down Expand Up @@ -262,22 +262,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
new HDFSBackedStateStore(version, newMap)
}

/**
 * Build a writable state store at `endVersion`, starting from the snapshot of `startVersion`
 * and applying the delta files that follow it. The underlying map comes from
 * `replayLoadedMapForStoreFromSnapshot`, which raises an error when that snapshot is missing.
 *
 * @param startVersion checkpoint version of the snapshot to start with
 * @param endVersion checkpoint version to end with
 * @return [[HDFSBackedStateStore]] at `endVersion`
 */
override def replayStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  val replayedMap = replayLoadedMapForStoreFromSnapshot(startVersion, endVersion)
  logInfo(
    log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
    log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
    log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
  new HDFSBackedStateStore(endVersion, replayedMap)
}

/** Get the state store for reading to specific `version` of the store. */
override def getReadStore(version: Long): ReadStateStore = {
val newMap = getLoadedMapForStore(version)
Expand All @@ -286,22 +270,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
new HDFSBackedReadStateStore(version, newMap)
}

/**
 * Build a read-only state store at `endVersion`, starting from the snapshot of `startVersion`
 * and applying the delta files that follow it. The underlying map comes from
 * `replayLoadedMapForStoreFromSnapshot`, which raises an error when that snapshot is missing.
 *
 * @param startVersion checkpoint version of the snapshot to start with
 * @param endVersion checkpoint version to end with
 * @return [[HDFSBackedReadStateStore]] at `endVersion`
 */
override def replayReadStoreFromSnapshot(startVersion: Long, endVersion: Long): ReadStateStore = {
  val replayedMap = replayLoadedMapForStoreFromSnapshot(startVersion, endVersion)
  logInfo(
    log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
    log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
    log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for readonly")
  new HDFSBackedReadStateStore(endVersion, replayedMap)
}

private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
try {
if (version < 0) {
Expand All @@ -318,33 +286,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
}

/**
 * Construct the state at `endVersion` by replaying delta files on top of the snapshot taken at
 * `startVersion`, and return it as a new [[HDFSBackedStateStoreMap]].
 *
 * Any failure, including a rejected version argument, is rethrown wrapped by
 * `QueryExecutionErrors.cannotLoadStore`.
 *
 * @param startVersion checkpoint version of the snapshot to start with; must be at least 1
 * @param endVersion checkpoint version to end with; must not be smaller than `startVersion`
 */
private def replayLoadedMapForStoreFromSnapshot(startVersion: Long, endVersion: Long):
  HDFSBackedStateStoreMap = synchronized {
  try {
    if (startVersion < 1) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
    }
    // startVersion >= 1 here, so this single comparison also rejects zero and negative values.
    if (endVersion < startVersion) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
    }

    // endVersion >= 1 is guaranteed at this point, so the replayed state is always loaded.
    val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
    newMap.putAll(constructMapFromSnapshot(startVersion, endVersion))
    newMap
  }
  catch {
    case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
  }
}

// Run bunch of validations specific to HDFSBackedStateStoreProvider
private def runValidation(
useColumnFamilies: Boolean,
Expand Down Expand Up @@ -604,33 +545,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
result
}

/**
 * Build the map for `endVersion`: take the state at `startVersion` (the entry in `loadedMaps`
 * when present, otherwise the snapshot file) and apply every later delta file in order.
 * Throws `stateStoreSnapshotFileNotFound` when neither source yields the starting state.
 */
private def constructMapFromSnapshot(startVersion: Long, endVersion: Long):
  HDFSBackedStateStoreMap = {
  val (replayed, elapsedMs) = Utils.timeTakenMs {
    // `synchronized` wraps only the cache lookup; the snapshot-file fallback runs after it.
    val cachedBase = synchronized { Option(loadedMaps.get(startVersion)) }
    val baseMap = cachedBase.orElse(readSnapshotFile(startVersion)).getOrElse {
      throw StateStoreErrors.stateStoreSnapshotFileNotFound(
        snapshotFile(startVersion).toString, toString())
    }

    // Copy the base state, then apply deltas startVersion + 1 through endVersion.
    val replayTarget = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
    replayTarget.putAll(baseMap)
    (startVersion + 1 to endVersion).foreach { deltaVersion =>
      updateFromDeltaFile(deltaVersion, replayTarget)
    }

    replayTarget
  }

  logDebug(s"Loading state from $startVersion to $endVersion takes $elapsedMs ms.")

  replayed
}

private def writeUpdateToDeltaFile(
output: DataOutputStream,
key: UnsafeRow,
Expand Down Expand Up @@ -975,4 +889,90 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
throw new IllegalStateException(msg)
}
}

/**
 * Return a writable state store for `endVersion`, rebuilt from the snapshot of `startVersion`
 * plus the delta files after it. The map is produced by `replayLoadedMapForStoreFromSnapshot`,
 * which raises an error if the snapshot for `startVersion` does not exist.
 *
 * @param startVersion checkpoint version of the snapshot to start with
 * @param endVersion checkpoint version to end with
 * @return [[HDFSBackedStateStore]] at `endVersion`
 */
override def replayStateFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  val replayedMap = replayLoadedMapForStoreFromSnapshot(startVersion, endVersion)
  logInfo(
    log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
    log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
    log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
  new HDFSBackedStateStore(endVersion, replayedMap)
}

/**
 * Return a read-only state store for `endVersion`, rebuilt from the snapshot of `startVersion`
 * plus the delta files after it. The map is produced by `replayLoadedMapForStoreFromSnapshot`,
 * which raises an error if the snapshot for `startVersion` does not exist.
 *
 * @param startVersion checkpoint version of the snapshot to start with
 * @param endVersion checkpoint version to end with
 * @return [[HDFSBackedReadStateStore]] at `endVersion`
 */
override def replayReadStateFromSnapshot(startVersion: Long, endVersion: Long): ReadStateStore = {
  val replayedMap = replayLoadedMapForStoreFromSnapshot(startVersion, endVersion)
  logInfo(
    log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
    log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
    log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for readonly")
  new HDFSBackedReadStateStore(endVersion, replayedMap)
}

/**
 * Construct the state at `endVersion` from the snapshot at `startVersion` by replaying the
 * delta files in between. Returns a new [[HDFSBackedStateStoreMap]].
 *
 * Every failure raised here, argument validation included, reaches the caller wrapped by
 * `QueryExecutionErrors.cannotLoadStore`.
 *
 * @param startVersion checkpoint version of the snapshot to start with; must be at least 1
 * @param endVersion checkpoint version to end with; must not be smaller than `startVersion`
 */
private def replayLoadedMapForStoreFromSnapshot(startVersion: Long, endVersion: Long):
  HDFSBackedStateStoreMap = synchronized {
  try {
    if (startVersion < 1) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
    }
    // With startVersion >= 1 established, a separate `endVersion < 0` test can never fire.
    if (endVersion < startVersion) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
    }

    // Both guards passed, so endVersion >= 1 and the replay is unconditional.
    val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
    newMap.putAll(constructMapFromSnapshot(startVersion, endVersion))
    newMap
  }
  catch {
    case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
  }
}

/**
 * Replay state from `startVersion` to `endVersion`. The starting state is the `loadedMaps`
 * entry for `startVersion` if there is one, else the snapshot file for that version; when
 * neither exists `stateStoreSnapshotFileNotFound` is thrown. Delta files for the versions
 * after `startVersion` up to `endVersion` are then applied in ascending order.
 */
private def constructMapFromSnapshot(startVersion: Long, endVersion: Long):
  HDFSBackedStateStoreMap = {
  val (replayed, elapsedMs) = Utils.timeTakenMs {
    // Only the cache lookup sits inside `synchronized`; the file fallback is evaluated after.
    val cachedBase = synchronized { Option(loadedMaps.get(startVersion)) }
    val baseMap = cachedBase.orElse(readSnapshotFile(startVersion)).getOrElse {
      throw StateStoreErrors.stateStoreSnapshotFileNotFound(
        snapshotFile(startVersion).toString, toString())
    }

    // Load all the deltas from the version after the start version up to the end version.
    val replayTarget = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
    replayTarget.putAll(baseMap)
    (startVersion + 1 to endVersion).foreach { deltaVersion =>
      updateFromDeltaFile(deltaVersion, replayTarget)
    }

    replayTarget
  }

  logDebug(s"Loading state from $startVersion to $endVersion takes $elapsedMs ms.")

  replayed
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

private[sql] class RocksDBStateStoreProvider
extends StateStoreProvider with Logging with Closeable with FineGrainedStateSource {
extends StateStoreProvider with Logging with Closeable
with SupportsFineGrainedReplayFromSnapshot {
import RocksDBStateStoreProvider._

class RocksDBStateStore(lastVersion: Long) extends StateStore {
Expand Down Expand Up @@ -309,51 +310,19 @@ private[sql] class RocksDBStateStoreProvider
}
}

/**
 * Validate the version range, delegate to `rocksDB.loadFromSnapshot(startVersion, endVersion)`
 * and return a new `RocksDBStateStore` for `endVersion`. Every failure, including a rejected
 * version, is rethrown wrapped by `QueryExecutionErrors.cannotLoadStore`.
 *
 * @param startVersion checkpoint version of the snapshot to start with; must be at least 1
 * @param endVersion checkpoint version to end with; must not be smaller than `startVersion`
 */
override def replayStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  try {
    // Reject an invalid start first, then an end that precedes the start.
    if (startVersion < 1) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
    } else if (endVersion < startVersion) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
    }
    rocksDB.loadFromSnapshot(startVersion, endVersion)
    new RocksDBStateStore(endVersion)
  } catch {
    case cause: Throwable => throw QueryExecutionErrors.cannotLoadStore(cause)
  }
}

override def getReadStore(version: Long): StateStore = {
try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
rocksDB.load(version, readOnly = true)
rocksDB.load(version, true)
new RocksDBStateStore(version)
}
catch {
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

/**
 * Read-path counterpart of the snapshot replay: validate the version range, delegate to
 * `rocksDB.loadFromSnapshot(startVersion, endVersion)` and return a new `RocksDBStateStore`
 * for `endVersion`. Every failure, including a rejected version, is rethrown wrapped by
 * `QueryExecutionErrors.cannotLoadStore`.
 *
 * @param startVersion checkpoint version of the snapshot to start with; must be at least 1
 * @param endVersion checkpoint version to end with; must not be smaller than `startVersion`
 */
override def replayReadStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  try {
    // Reject an invalid start first, then an end that precedes the start.
    if (startVersion < 1) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
    } else if (endVersion < startVersion) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
    }
    rocksDB.loadFromSnapshot(startVersion, endVersion)
    new RocksDBStateStore(endVersion)
  } catch {
    case cause: Throwable => throw QueryExecutionErrors.cannotLoadStore(cause)
  }
}

override def doMaintenance(): Unit = {
try {
rocksDB.doMaintenance()
Expand Down Expand Up @@ -399,6 +368,38 @@ private[sql] class RocksDBStateStoreProvider
/** Throws an [[IllegalStateException]] carrying `msg` unless `condition` evaluates to true. */
private def verify(condition: => Boolean, msg: String): Unit = {
  // `condition` is by-name; it is evaluated exactly once here.
  if (condition) () else throw new IllegalStateException(msg)
}

/**
 * Return a state store for `endVersion`, loaded through
 * `rocksDB.loadFromSnapshot(startVersion, endVersion)`.
 *
 * @param startVersion checkpoint version of the snapshot to start with; must be at least 1
 * @param endVersion checkpoint version to end with; must not be smaller than `startVersion`
 */
override def replayStateFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  loadStateStoreFromSnapshot(startVersion, endVersion)
}

/**
 * Read-path entry point for snapshot replay. It performs exactly the same validation and load
 * as `replayStateFromSnapshot` and likewise returns a `RocksDBStateStore` for `endVersion`.
 *
 * @param startVersion checkpoint version of the snapshot to start with; must be at least 1
 * @param endVersion checkpoint version to end with; must not be smaller than `startVersion`
 */
override def replayReadStateFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  loadStateStoreFromSnapshot(startVersion, endVersion)
}

/**
 * Shared implementation of the two snapshot-replay entry points: checks the version range,
 * calls `rocksDB.loadFromSnapshot`, and wraps the result in a `RocksDBStateStore`.
 */
private def loadStateStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
  try {
    if (startVersion < 1) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
    }
    if (endVersion < startVersion) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
    }
    rocksDB.loadFromSnapshot(startVersion, endVersion)
    new RocksDBStateStore(endVersion)
  }
  catch {
    // The version checks above throw inside this `try`, so they are wrapped as well.
    case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
  }
}
}

object RocksDBStateStoreProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,32 +444,32 @@ object StateStoreProvider {
* grained state data which is replayed from a specific snapshot version. It is used by the
* snapshotStartBatchId option in state data source.
*/
trait FineGrainedStateSource {
trait SupportsFineGrainedReplayFromSnapshot {
/**
* Used by snapshotStartBatchId option when reading state generated by join operation as data
* source.
* Return an instance of [[StateStore]] representing state data of the given version.
* The State Store will be constructed from the batch at startVersion, and applying delta files
* up to the endVersion. If there is no snapshot file of batch startVersion, an exception will
* be thrown.
* The State Store will be constructed from the snapshot at snapshotVersion by applying delta
* files up to the endVersion. If there is no snapshot file at snapshotVersion, an exception
* will be thrown.
*
* @param startVersion checkpoint version of the snapshot to start with
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
def replayStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore
def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore

/**
* Used by snapshotStartBatchId option when reading state generated by all stateful operations
* except join as data source.
* Return an instance of [[ReadStateStore]] representing state data of the given version.
* The State Store will be constructed from the batch at startVersion, and applying delta files
* up to the endVersion. If there is no snapshot file of batch startVersion, an exception will
* be thrown.
* The State Store will be constructed from the snapshot at snapshotVersion by applying delta
* files up to the endVersion. If there is no snapshot file at snapshotVersion, an exception
* will be thrown.
*
* @param startVersion checkpoint version of the snapshot to start with
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
def replayReadStoreFromSnapshot(startVersion: Long, endVersion: Long): ReadStateStore
def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,8 @@ class SymmetricHashJoinStateManager(
useColumnFamilies = false, storeConf, hadoopConf,
useMultipleValuesPerKey = false)
if (snapshotStartVersion.isDefined) {
stateStoreProvider.asInstanceOf[FineGrainedStateSource]
.replayStoreFromSnapshot(snapshotStartVersion.get, stateInfo.get.storeVersion)
stateStoreProvider.asInstanceOf[SupportsFineGrainedReplayFromSnapshot]
.replayStateFromSnapshot(snapshotStartVersion.get, stateInfo.get.storeVersion)
} else {
stateStoreProvider.getStore(stateInfo.get.storeVersion)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
}

val exc = intercept[SparkException] {
provider.asInstanceOf[FineGrainedStateSource].replayReadStoreFromSnapshot(1, 2)
provider.asInstanceOf[SupportsFineGrainedReplayFromSnapshot].replayReadStateFromSnapshot(1, 2)
}
checkError(exc, "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED")
})
Expand All @@ -1001,7 +1001,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
provider.doMaintenance()
}

val result = provider.asInstanceOf[FineGrainedStateSource].replayReadStoreFromSnapshot(2, 3)
val result = provider.asInstanceOf[SupportsFineGrainedReplayFromSnapshot].replayReadStateFromSnapshot(2, 3)

assert(get(result, "a", 1).get == 1)
assert(get(result, "a", 2).get == 2)
Expand Down

0 comments on commit 97ee3ef

Please sign in to comment.