From 6db0e3dd38b089efe691fe4f4c33880ea2995580 Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 4 Jun 2024 15:28:49 -0700
Subject: [PATCH] initial implementation

---
 .../resources/error/error-conditions.json    | 11 +++
 .../sql/errors/QueryExecutionErrors.scala    | 19 ++++
 .../v2/state/StateDataSource.scala           | 29 +++++-
 .../v2/state/StatePartitionReader.scala      |  9 +-
 .../v2/state/StateScanBuilder.scala          | 17 +++-
 .../state/HDFSBackedStateStoreProvider.scala | 91 +++++++++++++++++++
 .../state/RocksDBStateStoreProvider.scala    | 15 +++
 .../streaming/state/StateStore.scala         | 11 +++
 state-store-content-check.py                 | 33 +++++++
 9 files changed, 229 insertions(+), 6 deletions(-)
 create mode 100644 state-store-content-check.py

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 66708649e5646..5c703b6e8de2a 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -223,6 +223,11 @@
           "Error reading snapshot file <fileToRead> of <clazz>: key size cannot be <keySize>."
         ]
       },
+      "CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS" : {
+        "message" : [
+          "Error reading snapshot file <fileToRead> of <clazz>: <fileToRead> does not exist."
+        ]
+      },
       "CANNOT_READ_SNAPSHOT_FILE_VALUE_SIZE" : {
         "message" : [
          "Error reading snapshot file <fileToRead> of <clazz>: value size cannot be <valueSize>."
        ]
      },
@@ -3594,6 +3599,12 @@
     ],
     "sqlState" : "42K08"
   },
+  "SNAPSHOT_PARTITION_ID_NOT_FOUND" : {
+    "message" : [
+      "Partition id <snapshotPartitionId> not found for given state source."
+    ],
+    "sqlState" : "54054"
+  },
   "SORT_BY_WITHOUT_BUCKETING" : {
     "message" : [
       "sortBy must be used together with bucketBy."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 1f3283ebed059..414c38d3446a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2159,6 +2159,18 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       cause = null)
   }
 
+  def failedToReadSnapshotFileNotExistsError(
+      fileToRead: Path,
+      clazz: String,
+      f: Throwable): Throwable = {
+    new SparkException(
+      errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS",
+      messageParameters = Map(
+        "fileToRead" -> fileToRead.toString(),
+        "clazz" -> clazz),
+      cause = f)
+  }
+
   def failedToReadSnapshotFileValueSizeError(
       fileToRead: Path,
       clazz: String,
@@ -2186,6 +2198,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       cause = f)
   }
 
+  def snapshotPartitionNotFoundError(snapshotPartitionId: Long): Throwable = {
+    new SparkException(
+      errorClass = "SNAPSHOT_PARTITION_ID_NOT_FOUND",
+      messageParameters = Map("snapshotPartitionId" -> snapshotPartitionId.toString()),
+      cause = null)
+  }
+
   def cannotPurgeAsBreakInternalStateError(): SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(errorClass = "_LEGACY_ERROR_TEMP_2260")
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 1a8f444042c23..a941f3a9bf143 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -116,7 +116,9 @@ case class StateSourceOptions(
     batchId: Long,
     operatorId: Int,
     storeName: String,
-    joinSide: JoinSideValues) {
+    joinSide: JoinSideValues,
+    snapshotStartBatchId: Option[Long],
+    snapshotPartitionId: Option[Int]) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
 
   override def toString: String = {
@@ -131,6 +133,8 @@ object StateSourceOptions extends DataSourceOptions {
   val OPERATOR_ID = newOption("operatorId")
   val STORE_NAME = newOption("storeName")
   val JOIN_SIDE = newOption("joinSide")
+  val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
+  val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
@@ -190,7 +194,28 @@ object StateSourceOptions extends DataSourceOptions {
       throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
     }
 
-    StateSourceOptions(resolvedCpLocation, batchId, operatorId, storeName, joinSide)
+    val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
+    if (snapshotStartBatchId.exists(_ < 0)) {
+      throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
+    } else if (snapshotStartBatchId.exists(_ > batchId)) {
+      throw StateDataSourceErrors.invalidOptionValue(
+        SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
+    }
+
+    val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt)
+    if (snapshotPartitionId.exists(_ < 0)) {
+      throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
+    }
+
+    if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
+      throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID.toString)
+    } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
+      throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID.toString)
+    }
+
+    StateSourceOptions(
+      resolvedCpLocation, batchId, operatorId, storeName,
+      joinSide, snapshotStartBatchId, snapshotPartitionId)
   }
 
   private def resolvedCheckpointLocation(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index bbfe3a3f373ec..a6ea426993fd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -93,7 +93,14 @@ class StatePartitionReader(
   }
 
   private lazy val store: ReadStateStore = {
-    provider.getReadStore(partition.sourceOptions.batchId + 1)
+    if (partition.sourceOptions.snapshotStartBatchId.isEmpty) {
+      provider.getReadStore(partition.sourceOptions.batchId + 1)
+    } else {
+      provider.getReadStore(
+        partition.sourceOptions.snapshotStartBatchId.get + 1,
+        partition.sourceOptions.batchId + 1)
+    }
   }
 
   private lazy val iter: Iterator[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
index 0d69bf708e94f..838797e6dc122 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
@@ -81,9 +82,19 @@ class StateScan(
       assert((tail - head + 1) == partitionNums.length,
         s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}")
 
-      partitionNums.map {
-        pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
-      }.toArray
+      if (sourceOptions.snapshotPartitionId.isEmpty) {
+        partitionNums.map {
+          pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
+        }.toArray
+      } else {
+        val snapshotPartitionId = sourceOptions.snapshotPartitionId.get
+        if (partitionNums.contains(snapshotPartitionId)) {
+          Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
+        } else {
+          throw QueryExecutionErrors.snapshotPartitionNotFoundError(snapshotPartitionId)
+        }
+      }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 543cd74c489d0..42fbd8602adb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -269,6 +269,14 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     new HDFSBackedReadStateStore(version, newMap)
   }
 
+  override def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore = {
+    val newMap = getLoadedMapForStore(startVersion, endVersion)
+    logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
+      log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
+      log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for read-only")
+    new HDFSBackedReadStateStore(endVersion, newMap)
+  }
+
   private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
     try {
       if (version < 0) {
@@ -285,6 +293,27 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     }
   }
 
+  private def getLoadedMapForStore(startVersion: Long, endVersion: Long):
+    HDFSBackedStateStoreMap = synchronized {
+    try {
+      if (startVersion < 0) {
+        throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
+      }
+      if (endVersion < startVersion || endVersion < 0) {
+        throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
+      }
+
+      val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      if (!(startVersion == 0 && endVersion == 0)) {
+        newMap.putAll(loadMap(startVersion, endVersion))
+      }
+      newMap
+    } catch {
+      case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
+    }
+  }
+
   // Run bunch of validations specific to HDFSBackedStateStoreProvider
   private def runValidation(
       useColumnFamilies: Boolean,
@@ -544,6 +573,62 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     result
   }
 
+  private def loadMap(startVersion: Long, endVersion: Long): HDFSBackedStateStoreMap = {
+    val (result, elapsedMs) = Utils.timeTakenMs {
+      // Unlike loadMap(version), this variant does not search backwards for the most recent
+      // available snapshot: the caller must provide a startVersion whose snapshot file (or
+      // cached map) exists, otherwise an error is thrown below.
+      val startVersionMap =
+        synchronized { Option(loadedMaps.get(startVersion)) }
+          .orElse {
+            logWarning(
+              log"The state for version ${MDC(LogKeys.FILE_VERSION, startVersion)} doesn't " +
+                log"exist in loadedMaps. Reading snapshot file and delta files if needed... " +
+                log"Note that this is normal for the first batch of a starting query.")
+            readSnapshotFile(startVersion)
+          }
+      if (startVersionMap.isEmpty) {
+        throw QueryExecutionErrors.failedToReadSnapshotFileNotExistsError(
+          snapshotFile(startVersion), toString(), null)
+      }
+      synchronized { putStateIntoStateCacheMap(startVersion, startVersionMap.get) }
+
+      // Load all the deltas from the version after the start version up to the end version.
+      val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      resultMap.putAll(startVersionMap.get)
+      for (deltaVersion <- startVersion + 1 to endVersion) {
+        updateFromDeltaFile(deltaVersion, resultMap)
+      }
+
+      resultMap
+    }
+
+    logDebug(s"Loading state from $startVersion to $endVersion takes $elapsedMs ms.")
+
+    result
+  }
+
   private def writeUpdateToDeltaFile(
       output: DataOutputStream,
       key: UnsafeRow,
@@ -683,6 +768,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     }
   }
 
+  /**
+   * Try to read the snapshot file of the given version.
+   * If the snapshot file is not available, return None.
+   *
+   * @param version the version of the snapshot file to read
+   */
   private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
     val fileToRead = snapshotFile(version)
     val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index e7fc9f56dd9eb..b35a69fca487e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -309,6 +309,7 @@ private[sql] class RocksDBStateStoreProvider
     }
   }
 
+  // todo
   override def getReadStore(version: Long): StateStore = {
     try {
       if (version < 0) {
@@ -322,6 +323,20 @@ private[sql] class RocksDBStateStoreProvider
     }
   }
 
+  // TODO: snapshot-based construction is not implemented for RocksDB yet; endVersion is
+  // ignored and the state at `version` is loaded directly.
+  override def getReadStore(version: Long, endVersion: Long): StateStore = {
+    try {
+      if (version < 0) {
+        throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
+      }
+      rocksDB.load(version, true)
+      new RocksDBStateStore(version)
+    } catch {
+      case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
+    }
+  }
+
   override def doMaintenance(): Unit = {
     try {
       rocksDB.doMaintenance()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 8c2170abe3116..d6eeb7b6e7a17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -378,6 +378,17 @@
   def getReadStore(version: Long): ReadStateStore =
     new WrappedReadStateStore(getStore(version))
 
+  /**
+   * Return an instance of [[ReadStateStore]] representing the state data at endVersion.
+   * The store is constructed by loading the snapshot written for startVersion and then
+   * applying the delta files up to endVersion. If there is no snapshot file for
+   * startVersion, an exception is thrown.
+   *
+   * @param startVersion checkpoint version of the snapshot to start with
+   * @param endVersion   checkpoint version to end with
+   */
+  def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore
+
   /** Optional method for providers to allow for background maintenance (e.g. compactions) */
   def doMaintenance(): Unit = { }
 
diff --git a/state-store-content-check.py b/state-store-content-check.py
new file mode 100644
index 0000000000000..da319406c36e5
--- /dev/null
+++ b/state-store-content-check.py
@@ -0,0 +1,33 @@
+# Run inside a pyspark shell: the script assumes an active SparkSession named `spark`.
+from pyspark.sql.functions import window, col
+
+# aggregate operator
+q1 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().withWatermark("timestamp", "50 seconds")\
+    .groupBy(window("timestamp", "10 seconds")).count().select("window.start", "window.end", "count")\
+    .writeStream.format("memory").queryName("window").option("checkpointLocation", "/tmp/state/window").start()
+
+# join operator
+sdf1 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().withWatermark("timestamp", "50 seconds")
+sdf2 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().withWatermark("timestamp", "50 seconds")
+q2 = sdf1.join(sdf2, "timestamp")\
+    .writeStream.format("memory").queryName("join").option("checkpointLocation", "/tmp/state/join").start()
+
+# limit operator
+q3 = spark.readStream.format("rate").option("rowsPerSecond", 100).load().limit(20)\
+    .writeStream.format("console").queryName("limit").option("checkpointLocation", "/tmp/state/limit").start()
+
+
+# rm -rf /tmp/state/window
+
+# read from state source
+meta1 = spark.read.format("state-metadata").load("/tmp/state/window")
+state1 = spark.read.format("statestore").load("/tmp/state/window")
+
+state1_1 = spark.read.format("statestore").option("snapshotStartBatchId", 20).option("snapshotPartitionId", 40).load("/tmp/state/window")
+state1_2 = spark.read.format("statestore").option("batchId", 53).load("/tmp/state/window")
+
+meta2 = spark.read.format("state-metadata").load("/tmp/state/join")
+state2_1 = spark.read.format("statestore").option("storeName", "left-keyToNumValues").load("/tmp/state/join")
+state2_2 = spark.read.format("statestore").option("joinSide", "left").load("/tmp/state/join")
+
+meta3 = spark.read.format("state-metadata").load("/tmp/state/limit")
+state3 = spark.read.format("statestore").load("/tmp/state/limit")
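
A minimal usage sketch for the new snapshotStartBatchId / snapshotPartitionId options, following on from state-store-content-check.py above. It reuses that script's example values (batch 53, snapshot batch 20, partition 40) and checkpoint path, assumes an active `spark` session, assumes a snapshot file actually exists for batch 20, and assumes the state reader keeps exposing the key, value and partition_id columns. If the snapshot-plus-delta reconstruction is correct, the snapshot-based read should match a plain read of the same batch restricted to that partition:

    from pyspark.sql.functions import col

    # Plain read: state as of batch 53, all partitions.
    full = spark.read.format("statestore") \
        .option("batchId", 53) \
        .load("/tmp/state/window")

    # Snapshot-based read: rebuild partition 40 from the batch-20 snapshot,
    # then replay the delta files up to batch 53.
    replayed = spark.read.format("statestore") \
        .option("batchId", 53) \
        .option("snapshotStartBatchId", 20) \
        .option("snapshotPartitionId", 40) \
        .load("/tmp/state/window")

    # Both reads should agree on partition 40.
    expected = full.filter(col("partition_id") == 40)
    assert replayed.exceptAll(expected).count() == 0
    assert expected.exceptAll(replayed).count() == 0

Since the RocksDB provider's two-argument getReadStore is still marked as a todo in this patch, this check is only meaningful with the HDFS-backed provider for now.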