From 4999469173df8480da873cd718c07122616f5ee6 Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Thu, 25 Jul 2024 21:42:05 +0900
Subject: [PATCH] [SPARK-48849][SS] Create OperatorStateMetadataV2 for the TransformWithStateExec operator

### What changes were proposed in this pull request?

This PR introduces the OperatorStateMetadataV2 format, which integrates with the TransformWithStateExec operator. It records information about the TWS operator and will be used to enforce invariants between query runs. Each OperatorStateMetadataV2 holds a pointer to the StateSchemaV3 file for the corresponding operator. Purging will be introduced in a follow-up PR: https://github.com/apache/spark/pull/47286. A short illustrative sketch of reading these metadata files back is included after the diff.

### Why are the changes needed?

This is needed for State Metadata integration with the TransformWithState operator.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Added unit tests to OperatorStateMetadataSuite, StateSchemaCompatibilityCheckerSuite and TransformWithStateSuite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47445 from ericm-db/metadata-v2.

Authored-by: Eric Marnadi
Signed-off-by: Jungtaek Lim
---
 .../state/metadata/StateMetadataSource.scala | 65 +++--
 .../streaming/IncrementalExecution.scala | 23 +-
 .../StreamingSymmetricHashJoinExec.scala | 6 +-
 .../streaming/TransformWithStateExec.scala | 51 +++-
 .../state/OperatorStateMetadata.scala | 230 +++++++++++++++---
 .../StateSchemaCompatibilityChecker.scala | 39 +--
 .../streaming/statefulOperators.scala | 29 ++-
 .../state/OperatorStateMetadataSuite.scala | 33 +--
 ...StateSchemaCompatibilityCheckerSuite.scala | 35 ++-
 .../streaming/TransformWithStateSuite.scala | 99 +++++++-
 10 files changed, 508 insertions(+), 102 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 893984feabf11..0024ef1a5cae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH import org.apache.spark.sql.execution.streaming.CheckpointFileManager -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, StateStoreMetadataV1} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,6 +46,7 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, + operatorPropertiesJson: String, numColsPrefixKey: Int) { def toRow(): InternalRow = { new GenericInternalRow( @@ -55,6 +56,7 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, + UTF8String.fromString(operatorPropertiesJson), numColsPrefixKey)) } } @@ -68,6 +70,7 @@ object StateMetadataTableEntry { .add("numPartitions",
IntegerType) .add("minBatchId", LongType) .add("maxBatchId", LongType) + .add("operatorProperties", StringType) } } @@ -188,29 +191,59 @@ class StateMetadataPartitionReader( } else Array.empty } - private def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + // Need this to be accessible from IncrementalExecution for the planning rule. + private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = { val stateDir = new Path(checkpointLocation, "state") val opIds = fileManager .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted opIds.map { opId => - new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() + val operatorIdPath = new Path(stateDir, opId.toString) + // check if OperatorStateMetadataV2 path exists, if it does, read it + // otherwise, fall back to OperatorStateMetadataV1 + val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath) + val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) { + 2 + } else { + 1 + } + OperatorStateMetadataReader.createReader( + operatorIdPath, hadoopConf, operatorStateMetadataVersion).read() match { + case Some(metadata) => metadata + case None => OperatorStateMetadataV1(OperatorInfoV1(opId, null), + Array(StateStoreMetadataV1(null, -1, -1))) + } } } private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { allOperatorStateMetadata.flatMap { operatorStateMetadata => - require(operatorStateMetadata.version == 1) - val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] - operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => - StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, - operatorStateMetadataV1.operatorInfo.operatorName, - stateStoreMetadata.storeName, - stateStoreMetadata.numPartitions, - if (batchIds.nonEmpty) batchIds.head else -1, - if (batchIds.nonEmpty) batchIds.last else -1, - stateStoreMetadata.numColsPrefixKey - ) + require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) + operatorStateMetadata match { + case v1: OperatorStateMetadataV1 => + v1.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v1.operatorInfo.operatorId, + v1.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + null, + stateStoreMetadata.numColsPrefixKey + ) + } + case v2: OperatorStateMetadataV2 => + v2.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v2.operatorInfo.operatorId, + v2.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + v2.operatorPropertiesJson, + -1 // numColsPrefixKey is not available in OperatorStateMetadataV2 + ) + } + } } - } - }.iterator + }.iterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 722a3bd86b7e1..567fb1b98f14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,7 +37,7 @@ import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -208,13 +208,16 @@ class IncrementalExecution( } val schemaValidationResult = statefulOp. validateAndMaybeEvolveStateSchema(hadoopConf, currentBatchId, stateSchemaVersion) + val stateSchemaPaths = schemaValidationResult.map(_.schemaPath) // write out the state schema paths to the metadata file statefulOp match { - case stateStoreWriter: StateStoreWriter => - val metadata = stateStoreWriter.operatorStateMetadata() - // TODO: [SPARK-48849] Populate metadata with stateSchemaPaths if metadata version is v2 - val metadataWriter = new OperatorStateMetadataWriter(new Path( - checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf) + case ssw: StateStoreWriter => + val metadata = ssw.operatorStateMetadata(stateSchemaPaths) + val metadataWriter = OperatorStateMetadataWriter.createWriter( + new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString), + hadoopConf, + ssw.operatorStateMetadataVersion, + Some(currentBatchId)) metadataWriter.write(metadata) case _ => } @@ -456,8 +459,12 @@ class IncrementalExecution( val reader = new StateMetadataPartitionReader( new Path(checkpointLocation).getParent.toString, new SerializableConfiguration(hadoopConf)) - ret = reader.stateMetadata.map { metadataTableEntry => - metadataTableEntry.operatorId -> metadataTableEntry.operatorName + val opMetadataList = reader.allOperatorStateMetadata + ret = opMetadataList.map { + case OperatorStateMetadataV1(operatorInfo, _) => + operatorInfo.operatorId -> operatorInfo.operatorName + case OperatorStateMetadataV2(operatorInfo, _, _) => + operatorInfo.operatorId -> operatorInfo.operatorName }.toMap } catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index a303d4db66a01..c54917bdb7873 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -227,10 +227,12 @@ case class StreamingSymmetricHashJoinExec( private val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray + val stateStoreInfo = + stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index a4d525ad13fd4..d2b8f92aa985b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.JString +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -96,6 +100,8 @@ case class TransformWithStateExec( } } + override def operatorStateMetadataVersion: Int = 2 + /** * We initialize this processor handle in the driver to run the init function * and fetch the schemas of the state variables initialized in this processor. @@ -382,12 +388,47 @@ case class TransformWithStateExec( batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { assert(stateSchemaVersion >= 3) - val newColumnFamilySchemas = getColFamilySchemas() + val newSchemas = getColFamilySchemas() val stateSchemaDir = stateSchemaDirPath() - val stateSchemaFilePath = new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}") - List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - newColumnFamilySchemas.values.toList, session.sessionState, stateSchemaVersion, - schemaFilePath = Some(stateSchemaFilePath))) + val newStateSchemaFilePath = + new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}") + val metadataPath = new Path(getStateInfo.checkpointLocation, s"${getStateInfo.operatorId}") + val metadataReader = OperatorStateMetadataReader.createReader( + metadataPath, hadoopConf, operatorStateMetadataVersion) + val operatorStateMetadata = metadataReader.read() + val oldStateSchemaFilePath: Option[Path] = operatorStateMetadata match { + case Some(metadata) => + metadata match { + case v2: OperatorStateMetadataV2 => + Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePath)) + case _ => None + } + case None => None + } + List(StateSchemaCompatibilityChecker. + validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newSchemas.values.toList, session.sessionState, stateSchemaVersion, + storeName = StateStoreId.DEFAULT_STORE_NAME, + oldSchemaFilePath = oldStateSchemaFilePath, + newSchemaFilePath = Some(newStateSchemaFilePath))) + } + + /** Metadata of this stateful operator and its states stores. 
*/ + override def operatorStateMetadata( + stateSchemaPaths: List[String]): OperatorStateMetadata = { + val info = getStateInfo + val operatorInfo = OperatorInfoV1(info.operatorId, shortName) + // stateSchemaFilePath should be populated at this point + val stateStoreInfo = + Array(StateStoreMetadataV2( + StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head)) + + val operatorPropertiesJson: JValue = + ("timeMode" -> JString(timeMode.toString)) ~ + ("outputMode" -> JString(outputMode.toString)) + + val json = compact(render(operatorPropertiesJson)) + OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) } private def stateSchemaDirPath(): Path = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index dcea29085bf2b..df3de5d9ceab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -20,13 +20,17 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{BufferedReader, InputStreamReader} import java.nio.charset.StandardCharsets +import scala.reflect.ClassTag + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream +import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter} /** * Metadata for a state store instance. @@ -40,6 +44,21 @@ trait StateStoreMetadata { case class StateStoreMetadataV1(storeName: String, numColsPrefixKey: Int, numPartitions: Int) extends StateStoreMetadata +case class StateStoreMetadataV2( + storeName: String, + numColsPrefixKey: Int, + numPartitions: Int, + stateSchemaFilePath: String) + extends StateStoreMetadata with Serializable + +object StateStoreMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[StateStoreMetadataV2](implicitly[ClassTag[StateStoreMetadataV2]].runtimeClass) +} + /** * Information about a stateful operator. 
*/ @@ -51,7 +70,10 @@ trait OperatorInfo { case class OperatorInfoV1(operatorId: Long, operatorName: String) extends OperatorInfo trait OperatorStateMetadata { + def version: Int + + def operatorInfo: OperatorInfo } case class OperatorStateMetadataV1( @@ -60,12 +82,56 @@ case class OperatorStateMetadataV1( override def version: Int = 1 } -object OperatorStateMetadataUtils { +case class OperatorStateMetadataV2( + operatorInfo: OperatorInfoV1, + stateStoreInfo: Array[StateStoreMetadataV2], + operatorPropertiesJson: String) extends OperatorStateMetadata { + override def version: Int = 2 +} + +object OperatorStateMetadataUtils extends Logging { + + sealed trait OperatorStateMetadataReader { + def version: Int + + def read(): Option[OperatorStateMetadata] + } + + sealed trait OperatorStateMetadataWriter { + def version: Int + def write(operatorMetadata: OperatorStateMetadata): Unit + } private implicit val formats: Formats = Serialization.formats(NoTypeHints) - def metadataFilePath(stateCheckpointPath: Path): Path = - new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + def readMetadata(inputStream: FSDataInputStream): Option[OperatorStateMetadata] = { + val inputReader = + new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + try { + val versionStr = inputReader.readLine() + val version = MetadataVersionUtil.validateVersion(versionStr, 2) + Some(deserialize(version, inputReader)) + } finally { + inputStream.close() + } + } + + def writeMetadata( + outputStream: CancellableFSDataOutputStream, + operatorMetadata: OperatorStateMetadata, + metadataFilePath: Path): Unit = { + try { + outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8)) + OperatorStateMetadataUtils.serialize(outputStream, operatorMetadata) + outputStream.close() + } catch { + case e: Throwable => + logError( + log"Fail to write state metadata file to ${MDC(LogKeys.META_FILE, metadataFilePath)}", e) + outputStream.cancel() + throw e + } + } def deserialize( version: Int, @@ -73,6 +139,8 @@ object OperatorStateMetadataUtils { version match { case 1 => Serialization.read[OperatorStateMetadataV1](in) + case 2 => + Serialization.read[OperatorStateMetadataV2](in) case _ => throw new IllegalArgumentException(s"Failed to deserialize operator metadata with " + @@ -86,7 +154,8 @@ object OperatorStateMetadataUtils { operatorStateMetadata.version match { case 1 => Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1], out) - + case 2 => + Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV2], out) case _ => throw new IllegalArgumentException(s"Failed to serialize operator metadata with " + s"version=${operatorStateMetadata.version}") @@ -94,54 +163,153 @@ object OperatorStateMetadataUtils { } } +object OperatorStateMetadataReader { + def createReader( + stateCheckpointPath: Path, + hadoopConf: Configuration, + version: Int): OperatorStateMetadataReader = { + version match { + case 1 => + new OperatorStateMetadataV1Reader(stateCheckpointPath, hadoopConf) + case 2 => + new OperatorStateMetadataV2Reader(stateCheckpointPath, hadoopConf) + case _ => + throw new IllegalArgumentException(s"Failed to create reader for operator metadata " + + s"with version=$version") + } + } +} + +object OperatorStateMetadataWriter { + def createWriter( + stateCheckpointPath: Path, + hadoopConf: Configuration, + version: Int, + currentBatchId: Option[Long] = None): OperatorStateMetadataWriter = { + version match { + case 1 => + new 
OperatorStateMetadataV1Writer(stateCheckpointPath, hadoopConf) + case 2 => + if (currentBatchId.isEmpty) { + throw new IllegalArgumentException("currentBatchId is required for version 2") + } + new OperatorStateMetadataV2Writer(stateCheckpointPath, hadoopConf, currentBatchId.get) + case _ => + throw new IllegalArgumentException(s"Failed to create writer for operator metadata " + + s"with version=$version") + } + } +} + +object OperatorStateMetadataV1 { + def metadataFilePath(stateCheckpointPath: Path): Path = + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") +} + +object OperatorStateMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass) + + def metadataDirPath(stateCheckpointPath: Path): Path = + new Path(new Path(new Path(stateCheckpointPath, "_metadata"), "metadata"), "v2") + + def metadataFilePath(stateCheckpointPath: Path, currentBatchId: Long): Path = + new Path(metadataDirPath(stateCheckpointPath), currentBatchId.toString) +} + /** * Write OperatorStateMetadata into the state checkpoint directory. */ -class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configuration) - extends Logging { +class OperatorStateMetadataV1Writer( + stateCheckpointPath: Path, + hadoopConf: Configuration) + extends OperatorStateMetadataWriter with Logging { - private val metadataFilePath = OperatorStateMetadataUtils.metadataFilePath(stateCheckpointPath) + private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + override def version: Int = 1 + def write(operatorMetadata: OperatorStateMetadata): Unit = { if (fm.exists(metadataFilePath)) return fm.mkdirs(metadataFilePath.getParent) val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false) - try { - outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8)) - OperatorStateMetadataUtils.serialize(outputStream, operatorMetadata) - outputStream.close() - } catch { - case e: Throwable => - logError( - log"Fail to write state metadata file to ${MDC(LogKeys.META_FILE, metadataFilePath)}", e) - outputStream.cancel() - throw e - } + OperatorStateMetadataUtils.writeMetadata(outputStream, operatorMetadata, metadataFilePath) } } /** - * Read OperatorStateMetadata from the state checkpoint directory. + * Read OperatorStateMetadata from the state checkpoint directory. This class will only be + * used to read OperatorStateMetadataV1. + * OperatorStateMetadataV2 will be read by the OperatorStateMetadataLog. 
*/ -class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) { +class OperatorStateMetadataV1Reader( + stateCheckpointPath: Path, + hadoopConf: Configuration) extends OperatorStateMetadataReader { + override def version: Int = 1 - private val metadataFilePath = OperatorStateMetadataUtils.metadataFilePath(stateCheckpointPath) + private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) - def read(): OperatorStateMetadata = { + def read(): Option[OperatorStateMetadata] = { val inputStream = fm.open(metadataFilePath) - val inputReader = - new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - try { - val versionStr = inputReader.readLine() - val version = MetadataVersionUtil.validateVersion(versionStr, 1) - OperatorStateMetadataUtils.deserialize(version, inputReader) - } finally { - inputStream.close() + OperatorStateMetadataUtils.readMetadata(inputStream) + } +} + +class OperatorStateMetadataV2Writer( + stateCheckpointPath: Path, + hadoopConf: Configuration, + currentBatchId: Long) extends OperatorStateMetadataWriter { + + private val metadataFilePath = OperatorStateMetadataV2.metadataFilePath( + stateCheckpointPath, currentBatchId) + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + + override def version: Int = 2 + + override def write(operatorMetadata: OperatorStateMetadata): Unit = { + if (fm.exists(metadataFilePath)) return + + fm.mkdirs(metadataFilePath.getParent) + val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false) + OperatorStateMetadataUtils.writeMetadata(outputStream, operatorMetadata, metadataFilePath) + } +} + +class OperatorStateMetadataV2Reader( + stateCheckpointPath: Path, + hadoopConf: Configuration) extends OperatorStateMetadataReader { + + private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateCheckpointPath) + private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf) + + fm.mkdirs(metadataDirPath.getParent) + override def version: Int = 2 + + private def listBatches(): Array[Long] = { + if (!fm.exists(metadataDirPath)) { + return Array.empty + } + fm.list(metadataDirPath).map(_.getPath.getName.toLong).sorted + } + + override def read(): Option[OperatorStateMetadata] = { + val batches = listBatches() + if (batches.isEmpty) { + return None } + val lastBatchId = batches.last + val metadataFilePath = OperatorStateMetadataV2.metadataFilePath( + stateCheckpointPath, lastBatchId) + val inputStream = fm.open(metadataFilePath) + OperatorStateMetadataUtils.readMetadata(inputStream) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 3230098c74cd4..ca03de6f1ad3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -48,13 +48,14 @@ case class StateStoreColFamilySchema( class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, hadoopConf: Configuration, - schemaFilePath: Option[Path] = None) extends Logging { + oldSchemaFilePath: Option[Path] = None, + newSchemaFilePath: Option[Path] = None) extends Logging { - private 
val schemaFileLocation = if (schemaFilePath.isEmpty) { + private val schemaFileLocation = if (oldSchemaFilePath.isEmpty) { val storeCpLocation = providerId.storeId.storeCheckpointLocation() schemaFile(storeCpLocation) } else { - schemaFilePath.get + oldSchemaFilePath.get } private val fm = CheckpointFileManager.create(schemaFileLocation, hadoopConf) @@ -65,10 +66,6 @@ class StateSchemaCompatibilityChecker( val inStream = fm.open(schemaFileLocation) try { val versionStr = inStream.readUTF() - // Ensure that version 3 format has schema file path provided explicitly - if (versionStr == "v3" && schemaFilePath.isEmpty) { - throw new IllegalStateException("Schema file path is required for schema version 3") - } val schemaReader = SchemaReader.createSchemaReader(versionStr) schemaReader.read(inStream) } catch { @@ -98,7 +95,7 @@ class StateSchemaCompatibilityChecker( stateStoreColFamilySchema: List[StateStoreColFamilySchema], stateSchemaVersion: Int): Unit = { // Ensure that schema file path is passed explicitly for schema version 3 - if (stateSchemaVersion == 3 && schemaFilePath.isEmpty) { + if (stateSchemaVersion == 3 && newSchemaFilePath.isEmpty) { throw new IllegalStateException("Schema file path is required for schema version 3") } @@ -110,13 +107,19 @@ class StateSchemaCompatibilityChecker( private[sql] def createSchemaFile( stateStoreColFamilySchema: List[StateStoreColFamilySchema], schemaWriter: SchemaWriter): Unit = { - val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) + val schemaFilePath = newSchemaFilePath match { + case Some(path) => + fm.mkdirs(path.getParent) + path + case None => schemaFileLocation + } + val outStream = fm.createAtomic(schemaFilePath, overwriteIfPossible = false) try { schemaWriter.write(stateStoreColFamilySchema, outStream) outStream.close() } catch { case e: Throwable => - logError(log"Fail to write schema file to ${MDC(LogKeys.PATH, schemaFileLocation)}", e) + logError(log"Fail to write schema file to ${MDC(LogKeys.PATH, schemaFilePath)}", e) outStream.cancel() throw e } @@ -208,7 +211,10 @@ object StateSchemaCompatibilityChecker { * @param stateSchemaVersion - version of the state schema to be used * @param extraOptions - any extra options to be passed for StateStoreConf creation * @param storeName - optional state store name - * @param schemaFilePath - optional schema file path + * @param oldSchemaFilePath - optional path to the old schema file. If not provided, will default + * to the schema file location + * @param newSchemaFilePath - optional path to the destination schema file. + * Needed for schema version 3 * @return - StateSchemaValidationResult containing the result of the schema validation */ def validateAndMaybeEvolveStateSchema( @@ -219,7 +225,8 @@ object StateSchemaCompatibilityChecker { stateSchemaVersion: Int, extraOptions: Map[String, String] = Map.empty, storeName: String = StateStoreId.DEFAULT_STORE_NAME, - schemaFilePath: Option[Path] = None): StateSchemaValidationResult = { + oldSchemaFilePath: Option[Path] = None, + newSchemaFilePath: Option[Path] = None): StateSchemaValidationResult = { // SPARK-47776: collation introduces the concept of binary (in)equality, which means // in some collation we no longer be able to just compare the binary format of two // UnsafeRows to determine equality. 
For example, 'aaa' and 'AAA' can be "semantically" @@ -237,7 +244,7 @@ object StateSchemaCompatibilityChecker { val providerId = StateStoreProviderId(StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, 0, storeName), stateInfo.queryRunId) val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf, - schemaFilePath = schemaFilePath) + oldSchemaFilePath = oldSchemaFilePath, newSchemaFilePath = newSchemaFilePath) // regardless of configuration, we check compatibility to at least write schema file // if necessary // if the format validation for value schema is disabled, we also disable the schema @@ -261,6 +268,10 @@ object StateSchemaCompatibilityChecker { if (storeConf.stateSchemaCheckEnabled && result.isDefined) { throw result.get } - StateSchemaValidationResult(evolvedSchema, checker.schemaFileLocation.toString) + val schemaFileLocation = newSchemaFilePath match { + case Some(path) => path.toString + case None => checker.schemaFileLocation.toString + } + StateSchemaValidationResult(evolvedSchema, schemaFileLocation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 14f67460763b1..43d75c4b4d137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -73,6 +74,12 @@ trait StatefulOperator extends SparkPlan { } } + def metadataFilePath(): Path = { + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + } + // Function used to record state schema for the first time and validate it against proposed // schema changes in the future. Runs as part of a planning rule on the driver. // Returns the schema file path for operators that write this to the metadata file, @@ -142,6 +149,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp */ def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = Some(inputWatermarkMs) + def operatorStateMetadataVersion: Int = 1 + override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, @@ -157,6 +166,20 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp "number of state store instances") ) ++ stateStoreCustomMetrics ++ pythonMetrics + def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } + /** * Get the progress made by this stateful operator after execution. This should be called in * the driver after this SparkPlan has been executed and metrics have been updated. 
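// --------------------------------------------------------------------------------------
// Editor's note (illustrative, not part of the patch): the next hunk adds a
// stateSchemaPaths parameter to StateStoreWriter.operatorStateMetadata(). With the
// signatures introduced elsewhere in this patch, the driver-side planning rule in
// IncrementalExecution effectively runs, per stateful operator:
//
//   val results = statefulOp.validateAndMaybeEvolveStateSchema(
//     hadoopConf, currentBatchId, stateSchemaVersion)
//   val stateSchemaPaths = results.map(_.schemaPath)
//   val metadata = ssw.operatorStateMetadata(stateSchemaPaths) // V1 operators ignore the paths
//   OperatorStateMetadataWriter.createWriter(
//       new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
//       hadoopConf, ssw.operatorStateMetadataVersion, Some(currentBatchId))
//     .write(metadata)
// --------------------------------------------------------------------------------------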
@@ -190,7 +213,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** Metadata of this stateful operator and its states stores. */ - def operatorStateMetadata(): OperatorStateMetadata = { + def operatorStateMetadata( + stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = @@ -920,7 +944,8 @@ case class SessionWindowStateStoreSaveExec( keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = Array(StateStoreMetadataV1( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index dd8f7aab51dd0..65d32b474708a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -40,10 +40,11 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { operatorId: Int, expectedMetadata: OperatorStateMetadataV1): Unit = { val statePath = new Path(checkpointDir, s"state/$operatorId") - val operatorMetadata = new OperatorStateMetadataReader(statePath, hadoopConf).read() - .asInstanceOf[OperatorStateMetadataV1] - assert(operatorMetadata.operatorInfo == expectedMetadata.operatorInfo && - operatorMetadata.stateStoreInfo.sameElements(expectedMetadata.stateStoreInfo)) + val operatorMetadata = new OperatorStateMetadataV1Reader(statePath, hadoopConf).read() + .asInstanceOf[Option[OperatorStateMetadataV1]] + assert(operatorMetadata.isDefined) + assert(operatorMetadata.get.operatorInfo == expectedMetadata.operatorInfo && + operatorMetadata.get.stateStoreInfo.sameElements(expectedMetadata.stateStoreInfo)) } test("Serialize and deserialize stateful operator metadata") { @@ -52,14 +53,14 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val stateStoreInfo = (1 to 4).map(i => StateStoreMetadataV1(s"store$i", 1, 200)) val operatorInfo = OperatorInfoV1(1, "Join") val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray) - new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata) + new OperatorStateMetadataV1Writer(statePath, hadoopConf).write(operatorMetadata) checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata) val df = spark.read.format("state-metadata").load(checkpointDir.toString) // Commit log is empty, there is no available batch id. 
- checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L), - Row(1, "Join", "store2", 200, -1L, -1L), - Row(1, "Join", "store3", 200, -1L, -1L), - Row(1, "Join", "store4", 200, -1L, -1L) + checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L, null), + Row(1, "Join", "store2", 200, -1L, -1L, null), + Row(1, "Join", "store3", 200, -1L, -1L, null), + Row(1, "Join", "store4", 200, -1L, -1L, null) )) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1), Row(1), Row(1), Row(1))) @@ -118,10 +119,10 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val df = spark.read.format("state-metadata") .load(checkpointDir.toString) - checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L), - Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L), - Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L), - Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L) + checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L, null), + Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L, null), + Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L, null), + Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L, null) )) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0), Row(0), Row(0))) @@ -169,7 +170,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) val df = spark.read.format("state-metadata").load(checkpointDir.toString) - checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L))) + checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L, null))) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1))) } } @@ -202,8 +203,8 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1) val df = spark.read.format("state-metadata").load(checkpointDir.toString) - checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L), - Row(1, "stateStoreSave", "default", 5, 0L, 1L))) + checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L, null), + Row(1, "stateStoreSave", "default", 5, 0L, 1L, null))) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index f5a5d1277d05d..38533825ece90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -275,7 +275,8 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val schemaFilePath = Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf, - schemaFilePath = schemaFilePath) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = schemaFilePath) checker.createSchemaFile(storeColFamilySchema, SchemaHelper.SchemaWriter.createSchemaWriter(stateSchemaVersion)) val stateSchema = checker.readSchemaFile() @@ -359,6 
+360,14 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { } } + private def getNewSchemaPath(stateSchemaDir: Path, stateSchemaVersion: Int): Option[Path] = { + if (stateSchemaVersion == 3) { + Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) + } else { + None + } + } + private def verifyException( oldKeySchema: StructType, oldValueSchema: StructType, @@ -373,9 +382,9 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> formatValidationForValue.toString) + val stateSchemaDir = stateSchemaDirPath(stateInfo) Seq(2, 3).foreach { stateSchemaVersion => val schemaFilePath = if (stateSchemaVersion == 3) { - val stateSchemaDir = stateSchemaDirPath(stateInfo) Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) } else { None @@ -384,10 +393,13 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val oldStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, oldKeySchema, oldValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema))) + val newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion) val result = Try( StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = newSchemaFilePath, + extraOptions = extraOptions) ).toEither.fold(Some(_), _ => None) val ex = if (result.isDefined) { @@ -399,7 +411,12 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + extraOptions = extraOptions, + oldSchemaFilePath = stateSchemaVersion match { + case 3 => newSchemaFilePath + case _ => None + }, + newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion)) } } @@ -433,9 +450,9 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> formatValidationForValue.toString) + val stateSchemaDir = stateSchemaDirPath(stateInfo) Seq(2, 3).foreach { stateSchemaVersion => val schemaFilePath = if (stateSchemaVersion == 3) { - val stateSchemaDir = stateSchemaDirPath(stateInfo) Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) } else { None @@ -446,14 +463,18 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion), + extraOptions = extraOptions) val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, newKeySchema, newValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, 
newKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion), + extraOptions = extraOptions) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 2e65748cb4673..d55a16a60eac0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state._ @@ -63,6 +63,32 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S } } +class RunningCountStatefulProcessorInt + extends StatefulProcessor[String, String, (String, String)] { + @transient protected var _countState: ValueState[Int] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + // Class to verify stateful processor usage with adding processing time timers class RunningCountStatefulProcessorWithProcTimeTimer extends RunningCountStatefulProcessor { private def handleProcessingTimeBasedTimers( @@ -886,6 +912,77 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + test("transformWithState - verify that OperatorStateMetadataV2" + + " file is being written correctly") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream, + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream + ) + + val df = spark.read.format("state-metadata").load(checkpointDir.toString) + checkAnswer(df, Seq( + Row(0, "transformWithStateExec", "default", 5, 0L, 1L, + """{"timeMode":"NoTime","outputMode":"Update"}""") + )) + } + } + } + + test("test that invalid schema evolution fails query for column family") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + 
classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInt(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreValueSchemaNotCompatible] { + (t: Throwable) => { + assert(t.getMessage.contains("Please check number and type of fields.")) + } + } + ) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest {
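Editor's illustrative sketch (not part of the patch): assuming the classes added above are on the classpath and a TransformWithState query has committed at least one batch, the snippet below shows one way the per-batch OperatorStateMetadataV2 files could be read back. OperatorStateMetadataV1 keeps a single file at state/<operatorId>/_metadata/metadata, while V2 writes one file per batch under state/<operatorId>/_metadata/metadata/v2/<batchId>. The checkpoint path used here is hypothetical.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV2}

object OperatorStateMetadataV2ReadSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical per-operator state checkpoint path: <checkpoint>/state/<operatorId>.
    val operatorIdPath = new Path("/tmp/checkpoint/state/0")
    val hadoopConf = new Configuration()

    // Version 2 selects OperatorStateMetadataV2Reader, which lists
    // <operatorIdPath>/_metadata/metadata/v2 and deserializes the file for the latest batch.
    val reader = OperatorStateMetadataReader.createReader(operatorIdPath, hadoopConf, version = 2)

    reader.read() match {
      case Some(v2: OperatorStateMetadataV2) =>
        // operatorPropertiesJson carries the TWS invariants recorded by TransformWithStateExec,
        // e.g. {"timeMode":"NoTime","outputMode":"Update"}; each store entry points to the
        // StateSchemaV3 file written for that batch.
        println(s"operator=${v2.operatorInfo.operatorName}")
        v2.stateStoreInfo.foreach { s =>
          println(s"store=${s.storeName} schemaFile=${s.stateSchemaFilePath}")
        }
        println(s"properties=${v2.operatorPropertiesJson}")
      case other =>
        println(s"No OperatorStateMetadataV2 found: $other")
    }
  }
}

The same information is surfaced through the state-metadata data source, as exercised by the new test above via spark.read.format("state-metadata").load(checkpointDir).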