[SPARK-48849][SS] Create OperatorStateMetadataV2 for the TransformWithStateExec operator

### What changes were proposed in this pull request?

This PR introduces the OperatorStateMetadataV2 format, which integrates with the TransformWithStateExec operator. The metadata keeps information about the TransformWithState (TWS) operator and will be used to enforce invariants between query runs. Each OperatorStateMetadataV2 holds a pointer to the StateSchemaV3 file for the corresponding operator.
Purging will be introduced in a separate PR: #47286
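
For orientation, here is a minimal sketch of the shape of the new metadata as it is used in the diff below; field names are taken from this PR's code, while the real definitions (with their serialization logic) live in the streaming state package:

```scala
// Sketch only: V2 operator metadata written per stateful operator under
// <checkpointLocation>/state/<operatorId>.
case class OperatorInfoV1(operatorId: Long, operatorName: String)

case class StateStoreMetadataV2(
    storeName: String,
    numColsPrefixKey: Int,
    numPartitions: Int,
    stateSchemaFilePath: String)  // pointer to the StateSchemaV3 file

case class OperatorStateMetadataV2(
    operatorInfo: OperatorInfoV1,
    stateStoreInfo: Array[StateStoreMetadataV2],
    operatorPropertiesJson: String)  // e.g. timeMode/outputMode for TransformWithState
```
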
### Why are the changes needed?

This is needed for State Metadata integration with the TransformWithState operator.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Added unit tests to StateStoreSuite and TransformWithStateSuite

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47445 from ericm-db/metadata-v2.

Authored-by: Eric Marnadi <eric.marnadi@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
ericm-db authored and HeartSaVioR committed Jul 25, 2024
1 parent cf95e75 commit 4999469
Showing 10 changed files with 508 additions and 102 deletions.

@@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead
import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1}
import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, StateStoreMetadataV1}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -46,6 +46,7 @@ case class StateMetadataTableEntry(
numPartitions: Int,
minBatchId: Long,
maxBatchId: Long,
operatorPropertiesJson: String,
numColsPrefixKey: Int) {
def toRow(): InternalRow = {
new GenericInternalRow(
@@ -55,6 +56,7 @@
numPartitions,
minBatchId,
maxBatchId,
UTF8String.fromString(operatorPropertiesJson),
numColsPrefixKey))
}
}
@@ -68,6 +70,7 @@ object StateMetadataTableEntry {
.add("numPartitions", IntegerType)
.add("minBatchId", LongType)
.add("maxBatchId", LongType)
.add("operatorProperties", StringType)
}
}

@@ -188,29 +191,59 @@ class StateMetadataPartitionReader(
} else Array.empty
}

private def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
// Need this to be accessible from IncrementalExecution for the planning rule.
private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
val stateDir = new Path(checkpointLocation, "state")
val opIds = fileManager
.list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
opIds.map { opId =>
new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read()
val operatorIdPath = new Path(stateDir, opId.toString)
// check if OperatorStateMetadataV2 path exists, if it does, read it
// otherwise, fall back to OperatorStateMetadataV1
val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath)
val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) {
2
} else {
1
}
OperatorStateMetadataReader.createReader(
operatorIdPath, hadoopConf, operatorStateMetadataVersion).read() match {
case Some(metadata) => metadata
case None => OperatorStateMetadataV1(OperatorInfoV1(opId, null),
Array(StateStoreMetadataV1(null, -1, -1)))
}
}
}

private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
allOperatorStateMetadata.flatMap { operatorStateMetadata =>
require(operatorStateMetadata.version == 1)
val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata =>
StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId,
operatorStateMetadataV1.operatorInfo.operatorName,
stateStoreMetadata.storeName,
stateStoreMetadata.numPartitions,
if (batchIds.nonEmpty) batchIds.head else -1,
if (batchIds.nonEmpty) batchIds.last else -1,
stateStoreMetadata.numColsPrefixKey
)
require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2)
operatorStateMetadata match {
case v1: OperatorStateMetadataV1 =>
v1.stateStoreInfo.map { stateStoreMetadata =>
StateMetadataTableEntry(v1.operatorInfo.operatorId,
v1.operatorInfo.operatorName,
stateStoreMetadata.storeName,
stateStoreMetadata.numPartitions,
if (batchIds.nonEmpty) batchIds.head else -1,
if (batchIds.nonEmpty) batchIds.last else -1,
null,
stateStoreMetadata.numColsPrefixKey
)
}
case v2: OperatorStateMetadataV2 =>
v2.stateStoreInfo.map { stateStoreMetadata =>
StateMetadataTableEntry(v2.operatorInfo.operatorId,
v2.operatorInfo.operatorName,
stateStoreMetadata.storeName,
stateStoreMetadata.numPartitions,
if (batchIds.nonEmpty) batchIds.head else -1,
if (batchIds.nonEmpty) batchIds.last else -1,
v2.operatorPropertiesJson,
-1 // numColsPrefixKey is not available in OperatorStateMetadataV2
)
}
}
}
}
}.iterator
}.iterator
}
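
As a hedged usage sketch (not part of the diff), the new `operatorProperties` column can be inspected through the state metadata source shown in the file above; the checkpoint path and the `spark` session are assumptions, and the format name is assumed to be the short name this source registers:

```scala
// Read the state metadata written under a streaming query's checkpoint.
val metadataDf = spark.read
  .format("state-metadata")
  .load("/tmp/tws-checkpoint")  // hypothetical checkpoint root

// For a transformWithState operator, operatorProperties carries the JSON from
// OperatorStateMetadataV2 (e.g. timeMode and outputMode).
metadataDf
  .select("operatorId", "operatorName", "operatorProperties")
  .show(false)
```
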

@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -208,13 +208,16 @@ class IncrementalExecution(
}
val schemaValidationResult = statefulOp.
validateAndMaybeEvolveStateSchema(hadoopConf, currentBatchId, stateSchemaVersion)
val stateSchemaPaths = schemaValidationResult.map(_.schemaPath)
// write out the state schema paths to the metadata file
statefulOp match {
case stateStoreWriter: StateStoreWriter =>
val metadata = stateStoreWriter.operatorStateMetadata()
// TODO: [SPARK-48849] Populate metadata with stateSchemaPaths if metadata version is v2
val metadataWriter = new OperatorStateMetadataWriter(new Path(
checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf)
case ssw: StateStoreWriter =>
val metadata = ssw.operatorStateMetadata(stateSchemaPaths)
val metadataWriter = OperatorStateMetadataWriter.createWriter(
new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
hadoopConf,
ssw.operatorStateMetadataVersion,
Some(currentBatchId))
metadataWriter.write(metadata)
case _ =>
}
@@ -456,8 +459,12 @@ class IncrementalExecution(
val reader = new StateMetadataPartitionReader(
new Path(checkpointLocation).getParent.toString,
new SerializableConfiguration(hadoopConf))
ret = reader.stateMetadata.map { metadataTableEntry =>
metadataTableEntry.operatorId -> metadataTableEntry.operatorName
val opMetadataList = reader.allOperatorStateMetadata
ret = opMetadataList.map {
case OperatorStateMetadataV1(operatorInfo, _) =>
operatorInfo.operatorId -> operatorInfo.operatorName
case OperatorStateMetadataV2(operatorInfo, _, _) =>
operatorInfo.operatorId -> operatorInfo.operatorName
}.toMap
} catch {
case e: Exception =>

@@ -227,10 +227,12 @@ case class StreamingSymmetricHashJoinExec(
private val stateStoreNames =
SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)

override def operatorStateMetadata(): OperatorStateMetadata = {
override def operatorStateMetadata(
stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = {
val info = getStateInfo
val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
val stateStoreInfo =
stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
}


@@ -21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.JString
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -96,6 +100,8 @@ case class TransformWithStateExec(
}
}

override def operatorStateMetadataVersion: Int = 2

/**
* We initialize this processor handle in the driver to run the init function
* and fetch the schemas of the state variables initialized in this processor.
@@ -382,12 +388,47 @@ case class TransformWithStateExec(
batchId: Long,
stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
assert(stateSchemaVersion >= 3)
val newColumnFamilySchemas = getColFamilySchemas()
val newSchemas = getColFamilySchemas()
val stateSchemaDir = stateSchemaDirPath()
val stateSchemaFilePath = new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
newColumnFamilySchemas.values.toList, session.sessionState, stateSchemaVersion,
schemaFilePath = Some(stateSchemaFilePath)))
val newStateSchemaFilePath =
new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
val metadataPath = new Path(getStateInfo.checkpointLocation, s"${getStateInfo.operatorId}")
val metadataReader = OperatorStateMetadataReader.createReader(
metadataPath, hadoopConf, operatorStateMetadataVersion)
val operatorStateMetadata = metadataReader.read()
val oldStateSchemaFilePath: Option[Path] = operatorStateMetadata match {
case Some(metadata) =>
metadata match {
case v2: OperatorStateMetadataV2 =>
Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePath))
case _ => None
}
case None => None
}
List(StateSchemaCompatibilityChecker.
validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
newSchemas.values.toList, session.sessionState, stateSchemaVersion,
storeName = StateStoreId.DEFAULT_STORE_NAME,
oldSchemaFilePath = oldStateSchemaFilePath,
newSchemaFilePath = Some(newStateSchemaFilePath)))
}

/** Metadata of this stateful operator and its states stores. */
override def operatorStateMetadata(
stateSchemaPaths: List[String]): OperatorStateMetadata = {
val info = getStateInfo
val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
// stateSchemaFilePath should be populated at this point
val stateStoreInfo =
Array(StateStoreMetadataV2(
StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head))

val operatorPropertiesJson: JValue =
("timeMode" -> JString(timeMode.toString)) ~
("outputMode" -> JString(outputMode.toString))

val json = compact(render(operatorPropertiesJson))
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
}

private def stateSchemaDirPath(): Path = {
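
To close, a hedged illustration (not part of this PR) of how the operatorPropertiesJson written by TransformWithStateExec.operatorStateMetadata above could be read back on a later run to help enforce invariants between runs; the literal JSON values are assumed examples:

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

// json4s needs an implicit Formats instance for extract[].
implicit val formats: Formats = DefaultFormats

// Assumed example payload in the shape produced above (timeMode/outputMode).
val props = parse("""{"timeMode":"ProcessingTime","outputMode":"Update"}""")

// A later run could compare these stored values against the current query's
// settings before restoring state.
val storedTimeMode = (props \ "timeMode").extract[String]
val storedOutputMode = (props \ "outputMode").extract[String]
```
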