
initial implementation
eason-yuchen-liu committed Jun 4, 2024
1 parent a7da9b6 commit 6db0e3d
Showing 9 changed files with 229 additions and 6 deletions.
11 changes: 11 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -223,6 +223,11 @@
"Error reading snapshot file <fileToRead> of <clazz>: key size cannot be <keySize>."
]
},
"CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS" : {
"message" : [
"Error reading snapshot file <fileToRead> of <clazz>: <fileToRead> does not exist."
]
},
"CANNOT_READ_SNAPSHOT_FILE_VALUE_SIZE" : {
"message" : [
"Error reading snapshot file <fileToRead> of <clazz>: value size cannot be <valueSize>."
@@ -3594,6 +3599,12 @@
],
"sqlState" : "42K08"
},
"SNAPSHOT_PARTITION_ID_NOT_FOUND" : {
"message" : [
"Partition id <snapshotPartitionId> not found for given state source."
],
"sqlState" : "54054"
},
"SORT_BY_WITHOUT_BUCKETING" : {
"message" : [
"sortBy must be used together with bucketBy."
@@ -2159,6 +2159,18 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
cause = null)
}

def failedToReadSnapshotFileNotExistsError(
fileToRead: Path,
clazz: String,
f: Throwable): Throwable = {
new SparkException(
errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS",
messageParameters = Map(
"fileToRead" -> fileToRead.toString(),
"clazz" -> clazz),
cause = f)
}

def failedToReadSnapshotFileValueSizeError(
fileToRead: Path,
clazz: String,
@@ -2186,6 +2198,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
cause = f)
}

def snapshotPartitionNotFoundError(snapshotPartitionId: Long): Throwable = {
new SparkException(
errorClass = "SNAPSHOT_PARTITION_ID_NOT_FOUND",
messageParameters = Map("snapshotPartitionId" -> snapshotPartitionId.toString),
cause = null)
}

def cannotPurgeAsBreakInternalStateError(): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(errorClass = "_LEGACY_ERROR_TEMP_2260")
}
@@ -116,7 +116,9 @@ case class StateSourceOptions(
batchId: Long,
operatorId: Int,
storeName: String,
joinSide: JoinSideValues) {
joinSide: JoinSideValues,
snapshotStartBatchId: Option[Long],
snapshotPartitionId: Option[Int]) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -131,6 +133,8 @@ object StateSourceOptions extends DataSourceOptions {
val OPERATOR_ID = newOption("operatorId")
val STORE_NAME = newOption("storeName")
val JOIN_SIDE = newOption("joinSide")
val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -190,7 +194,28 @@ object StateSourceOptions extends DataSourceOptions {
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
}

StateSourceOptions(resolvedCpLocation, batchId, operatorId, storeName, joinSide)
val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
if (snapshotStartBatchId.exists(_ < 0)) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
} else if (snapshotStartBatchId.exists(_ > batchId)) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
}

val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt)
if (snapshotPartitionId.exists(_ < 0)) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
}

if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID)
} else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
}

StateSourceOptions(
resolvedCpLocation, batchId, operatorId, storeName,
joinSide, snapshotStartBatchId, snapshotPartitionId)
}

private def resolvedCheckpointLocation(
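For context, the two new options are intended to be used together like this (checkpoint path and ids are illustrative, mirroring the test script at the end of this commit; the read itself is a sketch, not part of the change):

// Both options must be set together, and snapshotStartBatchId must be <= batchId.
val restoredState = spark.read
  .format("statestore")
  .option("snapshotStartBatchId", 20)  // start from the snapshot written for version 21
  .option("snapshotPartitionId", 40)   // scan only this partition
  .option("batchId", 53)               // optional; defaults to the last committed batch
  .load("/tmp/state/window")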
@@ -93,7 +93,14 @@ class StatePartitionReader(
}

private lazy val store: ReadStateStore = {
provider.getReadStore(partition.sourceOptions.batchId + 1)
if (partition.sourceOptions.snapshotStartBatchId.isEmpty) {
provider.getReadStore(partition.sourceOptions.batchId + 1)
} else {
provider.getReadStore(
partition.sourceOptions.snapshotStartBatchId.get + 1,
partition.sourceOptions.batchId + 1)
}
}

private lazy val iter: Iterator[InternalRow] = {
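The "+ 1" above reflects the convention that a batch's committed state is stored as state-store version batchId + 1. As an illustrative example of the resulting call:

// snapshotStartBatchId = 20, batchId = 53  =>  provider.getReadStore(21, 54)
// i.e. the state as of version 54, rebuilt from the snapshot for version 21
// plus the delta files for versions 22 through 54.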
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
@@ -81,9 +82,19 @@ class StateScan(
assert((tail - head + 1) == partitionNums.length,
s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}")

partitionNums.map {
pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
}.toArray
if (sourceOptions.snapshotPartitionId.isEmpty) {
partitionNums.map {
pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
}.toArray
} else {
val snapshotPartitionId = sourceOptions.snapshotPartitionId.get
if (partitionNums.contains(snapshotPartitionId)) {
Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
} else {
throw QueryExecutionErrors.snapshotPartitionNotFoundError(snapshotPartitionId)
}
}
}
}

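To illustrate what the new check surfaces to the user, asking for a partition id that is not present in the state directory would fail roughly as follows (path and ids are illustrative):

// Assuming the query's state only has partitions 0..199:
spark.read.format("statestore")
  .option("snapshotStartBatchId", 20)
  .option("snapshotPartitionId", 9999)  // no such partition
  .load("/tmp/state/window")
// => SparkException: [SNAPSHOT_PARTITION_ID_NOT_FOUND] Partition id 9999 not found
//    for given state source.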
@@ -269,6 +269,14 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
new HDFSBackedReadStateStore(version, newMap)
}

override def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore = {
val newMap = getLoadedMapForStore(startVersion, endVersion)
logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for readonly")
new HDFSBackedReadStateStore(startVersion, newMap)
}

private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
try {
if (version < 0) {
@@ -285,6 +293,27 @@
}
}

private def getLoadedMapForStore(startVersion: Long, endVersion: Long):
HDFSBackedStateStoreMap = synchronized {
try {
if (startVersion < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
}
if (endVersion < startVersion || endVersion < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}

val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
if (!(startVersion == 0 && endVersion == 0)) {
newMap.putAll(loadMap(startVersion, endVersion))
}
newMap
} catch {
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

// Run bunch of validations specific to HDFSBackedStateStoreProvider
private def runValidation(
useColumnFamilies: Boolean,
@@ -544,6 +573,62 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
result
}

private def loadMap(startVersion: Long, endVersion: Long): HDFSBackedStateStoreMap = {

val (result, elapsedMs) = Utils.timeTakenMs {
// Unlike loadMap(version), there is no fallback to an earlier version here: the result
// must be reconstructed from the snapshot written for startVersion, so we fail fast if
// that snapshot is neither cached nor present on disk.

val startVersionMap =
synchronized { Option(loadedMaps.get(startVersion)) }
.orElse {
logWarning(
log"The state for version ${MDC(LogKeys.FILE_VERSION, startVersion)} doesn't " +
log"exist in loadedMaps. Reading snapshot file and delta files if needed..." +
log"Note that this is normal for the first batch of a starting query.")
readSnapshotFile(startVersion)
}
if (startVersionMap.isEmpty) {
throw QueryExecutionErrors.failedToReadSnapshotFileNotExistsError(
snapshotFile(startVersion), toString(), null)
}
synchronized { putStateIntoStateCacheMap(startVersion, startVersionMap.get) }

// Load all the deltas from the version after the start version up to the end version.
val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
resultMap.putAll(startVersionMap.get)
for (deltaVersion <- startVersion + 1 to endVersion) {
updateFromDeltaFile(deltaVersion, resultMap)
}

resultMap
}

logDebug(s"Loading state from $startVersion to $endVersion took $elapsedMs ms.")

result
}
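In terms of the files under the operator's state directory, the range load above behaves roughly as sketched below (file names follow the provider's usual <version>.snapshot / <version>.delta scheme; versions are illustrative):

// loadMap(startVersion = 21, endVersion = 54):
//   1. take version 21 from loadedMaps if cached, otherwise read 21.snapshot
//      (if neither exists, fail with failedToReadSnapshotFileNotExistsError)
//   2. copy that map and replay 22.delta, 23.delta, ..., 54.delta
//      via updateFromDeltaFile
//   3. return the resulting HDFSBackedStateStoreMap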

private def writeUpdateToDeltaFile(
output: DataOutputStream,
key: UnsafeRow,
@@ -683,6 +768,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
}

/**
* Try to read the snapshot file of the given version.
* If the snapshot file is not available, return None.
*
* @param version the version of the snapshot file to read
*/
private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
val fileToRead = snapshotFile(version)
val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
@@ -309,6 +309,7 @@ private[sql] class RocksDBStateStoreProvider
}
}

// todo
override def getReadStore(version: Long): StateStore = {
try {
if (version < 0) {
@@ -322,6 +323,20 @@ private[sql] class RocksDBStateStoreProvider
}
}

// TODO: build the store from the snapshot at `version` and replay changes up to
// `endVersion`; this initial implementation ignores `endVersion`.
override def getReadStore(version: Long, endVersion: Long): StateStore = {
try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
rocksDB.load(version, true)
new RocksDBStateStore(version)
} catch {
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

override def doMaintenance(): Unit = {
try {
rocksDB.doMaintenance()
@@ -378,6 +378,17 @@ trait StateStoreProvider {
def getReadStore(version: Long): ReadStateStore =
new WrappedReadStateStore(getStore(version))

/**
* Return an instance of [[ReadStateStore]] representing state data of the given version range.
* The state store is constructed from the snapshot at startVersion, and the delta files up to
* endVersion are then applied on top of it. If there is no snapshot file for batch startVersion,
* an exception is thrown.
*
* @param startVersion Batch ID of the snapshot to start with
* @param endVersion Batch ID to end with
*/
def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore

/** Optional method for providers to allow for background maintenance (e.g. compactions) */
def doMaintenance(): Unit = { }

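A hypothetical caller of the new overload might look like the sketch below (provider and version numbers are assumed; error handling elided):

// Read-only view of version 54, built from the snapshot written for version 21.
// Throws if the provider cannot find that snapshot.
val readStore: ReadStateStore = provider.getReadStore(21L, 54L)
try {
  readStore.iterator().foreach { pair =>
    // pair.key / pair.value are UnsafeRows in the store's key/value schemas
  }
} finally {
  readStore.abort()  // release the read-only store; nothing is committed
}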
33 changes: 33 additions & 0 deletions state-store-content-check.py
@@ -0,0 +1,33 @@
from pyspark.sql.functions import window, col

# aggregate operator
q1 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().withWatermark("timestamp", "50 seconds")\
.groupBy(window("timestamp", "10 seconds")).count().select("window.start", "window.end", "count")\
.writeStream.format("memory").queryName("window").option("checkpointLocation", "/tmp/state/window").start()

# join operator
sdf1 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().withWatermark("timestamp", "50 seconds")
sdf2 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().withWatermark("timestamp", "50 seconds")
q2 = sdf1.join(sdf2, "timestamp").select()\
.writeStream.format("memory").queryName("join").option("checkpointLocation", "/tmp/state/join").start()

# limit operator
q3 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().limit(20)\
.writeStream.format("console").queryName("limit").option("checkpointLocation", "/tmp/state/limit").start()


# rm -rf /tmp/state/window

# read from state source
meta1 = spark.read.format("state-metadata").load("/tmp/state/window")
state1 = spark.read.format("statestore").load("/tmp/state/window")

spark.read.format("statestore").option("snapshotStartBatchId", 20).option("snapshotPartitionId", 40).load("/tmp/state/window").show()
spark.read.format("statestore").option("batchId", 53).load("/tmp/state/window").show()

meta2 = spark.read.format("state-metadata").load("/tmp/state/join")
state2_1 = spark.read.format("statestore").option("storeName", "left-keyToNumValues").load("/tmp/state/join")
state2_2 = spark.read.format("statestore").option("joinSide", "left").load("/tmp/state/join")

meta3 = spark.read.format("state-metadata").load("/tmp/state/limit")
state3 = spark.read.format("statestore").load("/tmp/state/limit")
