diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 42015a5bd29ee..1b6b17cf11d9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -21,6 +21,7 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR} @@ -37,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -187,6 +188,52 @@ class IncrementalExecution( } } + def writeSchemaAndMetadataFiles( + stateSchemaV3File: StateSchemaV3File, + operatorStateMetadataLog: OperatorStateMetadataLog, + stateSchema: JValue, + operatorStateMetadata: OperatorStateMetadata): Unit = { + operatorStateMetadataLog.purgeAfter(currentBatchId - 1) + if (!stateSchemaV3File.add(currentBatchId, stateSchema)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId) + } + if (!operatorStateMetadataLog.add(currentBatchId, operatorStateMetadata)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId) + } + } + + object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging { + override val rule: PartialFunction[SparkPlan, SparkPlan] = { + case tws: TransformWithStateExec if isFirstBatch => + val stateSchemaV3File = new StateSchemaV3File( + hadoopConf, tws.stateSchemaFilePath().toString) + val operatorStateMetadataLog = new OperatorStateMetadataLog( + hadoopConf, + tws.metadataFilePath().toString + ) + stateSchemaV3File.getLatest() match { + case Some((_, oldSchema)) => + val newSchema = tws.getSchema() + tws.compareSchemas(oldSchema, newSchema) + writeSchemaAndMetadataFiles( + stateSchemaV3File = stateSchemaV3File, + operatorStateMetadataLog = operatorStateMetadataLog, + stateSchema = newSchema, + operatorStateMetadata = tws.operatorStateMetadata() + ) + tws.copy(columnFamilyJValue = Some(oldSchema)) + case None => + writeSchemaAndMetadataFiles( + stateSchemaV3File = stateSchemaV3File, + operatorStateMetadataLog = operatorStateMetadataLog, + stateSchema = tws.getSchema(), + operatorStateMetadata = tws.operatorStateMetadata() + ) + tws + } + } + } + object StateOpIdRule extends SparkPlanPartialRule { override val rule: PartialFunction[SparkPlan, SparkPlan] = { case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, @@ -455,15 +502,16 @@ class IncrementalExecution( override def apply(plan: SparkPlan): SparkPlan = { val planWithStateOpId = plan transform composedRule + val planWithSchema = planWithStateOpId transform PopulateSchemaV3Rule.rule // Need to check 
before write to metadata because we need to detect add operator // Only check when streaming is restarting and is first batch if (isFirstBatch && currentBatchId != 0) { - checkOperatorValidWithMetadata(planWithStateOpId) + checkOperatorValidWithMetadata(planWithSchema) } // The rule doesn't change the plan but cause the side effect that metadata is written // in the checkpoint directory of stateful operator. - simulateWatermarkPropagation(planWithStateOpId) - planWithStateOpId transform WatermarkPropagationRule.rule + simulateWatermarkPropagation(planWithSchema) + planWithSchema transform WatermarkPropagationRule.rule } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 56c9d2664d9e2..429464a5467b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -20,9 +20,19 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState +object ListStateImpl { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + true) + } +} /** * Provides concrete implementation for list of values associated with a state variable * used in the streaming transformWithState operator. @@ -44,8 +54,7 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + store.createColFamilyIfAbsent(ListStateImpl.columnFamilySchema(stateName)) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index dc72f8bcd5600..969ad8a889fc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -19,10 +19,20 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} import org.apache.spark.util.NextIterator +object ListStateImplWithTTL { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + true) + } +} /** * Class that provides a concrete implementation for a list state state associated with state * variables (with ttl expiration support) used in the streaming transformWithState operator. @@ -55,8 +65,7 @@ class ListStateImplWithTTL[S]( initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + store.createColFamilyIfAbsent(ListStateImplWithTTL.columnFamilySchema(stateName)) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index c58f32ed756db..0d3a0be5cf5e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -19,9 +19,19 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState -import org.apache.spark.sql.types.{BinaryType, StructType} + +object MapStateImpl { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + COMPOSITE_KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) + } +} class MapStateImpl[K, V]( store: StateStore, @@ -30,18 +40,11 @@ class MapStateImpl[K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V]) extends MapState[K, V] with Logging { - // Pack grouping key and user key together as a prefixed composite key - private val schemaForCompositeKeyRow: StructType = - new StructType() - .add("key", BinaryType) - .add("userKey", BinaryType) - private val schemaForValueRow: StructType = new StructType().add("value", BinaryType) private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = new CompositeKeyStateEncoder( - keySerializer, userKeyEnc, valEncoder, schemaForCompositeKeyRow, stateName) + keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow, - PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1)) + store.createColFamilyIfAbsent(MapStateImpl.columnFamilySchema(stateName)) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index 2ab06f36dd5f7..cb99ccf248f9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -20,10 +20,19 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} import org.apache.spark.util.NextIterator +object MapStateImplWithTTL { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + COMPOSITE_KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA_WITH_TTL, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) + } +} /** * Class that provides a concrete implementation for map state associated with state * variables (with ttl expiration support) used in the streaming transformWithState operator. @@ -58,8 +67,7 @@ class MapStateImplWithTTL[K, V]( initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1)) + store.createColFamilyIfAbsent(MapStateImplWithTTL.columnFamilySchema(stateName)) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index f85adf8c34363..20dfcd7c7fd8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -903,23 +903,6 @@ class MicroBatchExecution( if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } - val shouldWriteMetadatas = execCtx.previousContext match { - case Some(prevCtx) - if prevCtx.executionPlan.runId == execCtx.executionPlan.runId => - false - case _ => true - } - if (shouldWriteMetadatas) { - execCtx.executionPlan.executedPlan.collect { - case s: StateStoreWriter => - val metadata = s.operatorStateMetadata() - val id = metadata.operatorInfo.operatorId - val metadataFile = operatorStateMetadataLogs(id) - if (!metadataFile.add(execCtx.batchId, metadata)) { - throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) - } - } - } } committedOffsets ++= execCtx.endOffsets } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala index f77875279384f..8bbd589a12885 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OperatorStateMetadataLog.scala @@ -66,4 +66,32 @@ class OperatorStateMetadataLog( case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader) } } + + + /** + * Store the metadata for the specified batchId and return `true` if successful. If the batchId's + * metadata has already been stored, this method will return `false`. + */ + override def add(batchId: Long, metadata: OperatorStateMetadata): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") + val batchMetadataFile = batchIdToPath(batchId) + if (fileManager.exists(batchMetadataFile)) { + fileManager.delete(batchMetadataFile) + } + val res = addNewBatchByStream(batchId) { output => serialize(metadata, output) } + if (metadataCacheEnabled && res) batchCache.put(batchId, metadata) + res + } + + override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): Boolean = { + val batchMetadataFile = batchIdToPath(batchId) + + if (metadataCacheEnabled && batchCache.containsKey(batchId)) { + false + } else { + write(batchMetadataFile, fn) + true + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala new file mode 100644 index 0000000000000..82bab9a5301f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStream, OutputStream, StringReader} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} +import org.json4s.JValue +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +class StateSchemaV3File( + hadoopConf: Configuration, + path: String, + metadataCacheEnabled: Boolean = false) + extends HDFSMetadataLog[JValue](hadoopConf, path, metadataCacheEnabled) { + + final val MAX_UTF_CHUNK_SIZE = 65535 + def this(sparkSession: SparkSession, path: String) = { + this( + sparkSession.sessionState.newHadoopConf(), + path, + metadataCacheEnabled = sparkSession.sessionState.conf.getConf( + SQLConf.STREAMING_METADATA_CACHE_ENABLED) + ) + } + + override protected def serialize(schema: JValue, out: OutputStream): Unit = { + val json = compact(render(schema)) + val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + + val outputStream = out.asInstanceOf[FSDataOutputStream] + // DataOutputStream.writeUTF can't write a string at once + // if the size exceeds 65535 (2^16 - 1) bytes. + // Each metadata consists of multiple chunks in schema version 3. + try { + val numMetadataChunks = (json.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val metadataStringReader = new StringReader(json) + outputStream.writeInt(numMetadataChunks) + (0 until numMetadataChunks).foreach { _ => + val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) + outputStream.writeUTF(new String(buf, 0, numRead)) + } + outputStream.close() + } catch { + case e: Throwable => + throw e + } + } + + override protected def deserialize(in: InputStream): JValue = { + val buf = new StringBuilder + val inputStream = in.asInstanceOf[FSDataInputStream] + val numKeyChunks = inputStream.readInt() + (0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF())) + val json = buf.toString() + JsonMethods.parse(json) + } + + override def add(batchId: Long, metadata: JValue): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") + val batchMetadataFile = batchIdToPath(batchId) + if (fileManager.exists(batchMetadataFile)) { + fileManager.delete(batchMetadataFile) + } + val res = addNewBatchByStream(batchId) { output => serialize(metadata, output) } + if (metadataCacheEnabled && res) batchCache.put(batchId, metadata) + res + } + + override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): Boolean = { + val batchMetadataFile = batchIdToPath(batchId) + + if (metadataCacheEnabled && batchCache.containsKey(batchId)) { + false + } else { + write(batchMetadataFile, fn) + true + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index e1d578fb2e5ca..b14eea3e5feb7 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -85,7 +85,8 @@ class StatefulProcessorHandleImpl( timeMode: TimeMode, isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, - metrics: Map[String, SQLMetric] = Map.empty) + metrics: Map[String, SQLMetric] = Map.empty, + existingColumnFamilies: Map[String, ColumnFamilySchema] = Map.empty) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ @@ -98,6 +99,9 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() + private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] = + new util.ArrayList[ColumnFamilySchema]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" private def buildQueryInfo(): QueryInfo = { @@ -139,6 +143,8 @@ class StatefulProcessorHandleImpl( new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) case None => stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) + val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName) + columnFamilySchemas.add(colFamilySchema) null } } @@ -158,6 +164,8 @@ class StatefulProcessorHandleImpl( valueStateWithTTL case None => stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) + val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName) + columnFamilySchemas.add(colFamilySchema) null } } @@ -247,9 +255,11 @@ class StatefulProcessorHandleImpl( * @param stateName - name of the state variable */ override def deleteIfExists(stateName: String): Unit = { - verifyStateVarOperations("delete_if_exists") - if (store.get.removeColFamilyIfExists(stateName)) { - incrementMetric("numDeletedStateVars") + if (store.isDefined) { + verifyStateVarOperations("delete_if_exists") + if (store.get.removeColFamilyIfExists(stateName)) { + incrementMetric("numDeletedStateVars") + } } } @@ -261,6 +271,8 @@ class StatefulProcessorHandleImpl( new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) case None => stateVariables.add(new StateVariableInfo(stateName, ListState, false)) + val colFamilySchema = ListStateImpl.columnFamilySchema(stateName) + columnFamilySchemas.add(colFamilySchema) null } } @@ -296,6 +308,8 @@ class StatefulProcessorHandleImpl( listStateWithTTL case None => stateVariables.add(new StateVariableInfo(stateName, ListState, true)) + val colFamilySchema = ListStateImplWithTTL.columnFamilySchema(stateName) + columnFamilySchemas.add(colFamilySchema) null } } @@ -311,6 +325,8 @@ class StatefulProcessorHandleImpl( new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) case None => stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) + val colFamilySchema = MapStateImpl.columnFamilySchema(stateName) + columnFamilySchemas.add(colFamilySchema) null } } @@ -331,6 +347,8 @@ class StatefulProcessorHandleImpl( mapStateWithTTL case None => stateVariables.add(new StateVariableInfo(stateName, MapState, true)) + val colFamilySchema = MapStateImplWithTTL.columnFamilySchema(stateName) + columnFamilySchemas.add(colFamilySchema) null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 605f536122f92..63d009ed928c0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -244,6 +244,10 @@ abstract class StreamExecution( populateOperatorStateMetadatas(getLatestExecutionContext().executionPlan.executedPlan) } + lazy val stateSchemaLogs: Map[Long, StateSchemaV3File] = { + populateStateSchemaFiles(getLatestExecutionContext().executionPlan.executedPlan) + } + private def populateOperatorStateMetadatas( plan: SparkPlan): Map[Long, OperatorStateMetadataLog] = { plan.flatMap { @@ -256,6 +260,18 @@ abstract class StreamExecution( }.toMap } + private def populateStateSchemaFiles( + plan: SparkPlan): Map[Long, StateSchemaV3File] = { + plan.flatMap { + case s: StateStoreWriter => s.stateInfo.map { info => + val schemaFilePath = s.stateSchemaFilePath() + info.operatorId -> new StateSchemaV3File(sparkSession, + schemaFilePath.toString) + } + case _ => Seq.empty + }.toMap + } + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING @@ -699,6 +715,8 @@ abstract class StreamExecution( protected def purgeOldest(): Unit = { operatorStateMetadataLogs.foreach( _._2.purgeOldest(minLogEntriesToMaintain)) + stateSchemaLogs.foreach( + _._2.purgeOldest(minLogEntriesToMaintain)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index d085581173559..22528d6f7068f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -81,7 +81,8 @@ case class TransformWithStateExec( initialStateGroupingAttrs: Seq[Attribute], initialStateDataAttrs: Seq[Attribute], initialStateDeserializer: Expression, - initialState: SparkPlan) + initialState: SparkPlan, + columnFamilyJValue: Option[JValue] = None) extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { val operatorProperties: util.Map[String, JValue] = @@ -91,7 +92,6 @@ case class TransformWithStateExec( override def shortName: String = "transformWithStateExec" - /** Metadata of this stateful operator and its states stores. */ override def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo @@ -99,14 +99,46 @@ case class TransformWithStateExec( val stateStoreInfo = Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions)) + val driverProcessorHandle = getDriverProcessorHandle + val stateVariables = JArray(driverProcessorHandle.stateVariables. + asScala.map(_.jsonValue).toList) + + closeProcessorHandle(driverProcessorHandle) val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ ("outputMode" -> JString(outputMode.toString)) ~ - ("stateVariables" -> operatorProperties.get("stateVariables")) + ("stateVariables" -> stateVariables) val json = compact(render(operatorPropertiesJson)) OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) } + def getSchema(): JValue = { + val driverProcessorHandle = getDriverProcessorHandle + val columnFamilySchemas = JArray(driverProcessorHandle. 
+ columnFamilySchemas.asScala.map(_.jsonValue).toList) + closeProcessorHandle(driverProcessorHandle) + columnFamilySchemas + } + + def compareSchemas(oldSchema: JValue, newSchema: JValue): Unit = { + val oldColumnFamilies = ColumnFamilySchemaV1.fromJValue(oldSchema) + val newColumnFamilies = ColumnFamilySchemaV1.fromJValue(newSchema).map { + case c1: ColumnFamilySchemaV1 => + c1.columnFamilyName -> c1 + }.toMap + + oldColumnFamilies.foreach { + case oldColumnFamily: ColumnFamilySchemaV1 => + newColumnFamilies.get(oldColumnFamily.columnFamilyName) match { + case Some(newColumnFamily) if oldColumnFamily.json != newColumnFamily.json => + throw new RuntimeException( + s"State variable with name ${newColumnFamily.columnFamilyName}" + + s" already exists with different schema.") + case _ => // do nothing + } + } + } + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { if (timeMode == ProcessingTime) { // TODO: check if we can return true only if actual timers are registered, or there is @@ -368,22 +400,26 @@ case class TransformWithStateExec( ) } - override protected def doExecute(): RDD[InternalRow] = { - metrics // force lazy init at driver - - validateTimeMode() - + protected def getDriverProcessorHandle: StatefulProcessorHandleImpl = { val driverProcessorHandle = new StatefulProcessorHandleImpl( None, getStateInfo.queryRunId, keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) - driverProcessorHandle.setHandleState(StatefulProcessorHandleState.PRE_INIT) statefulProcessor.setHandle(driverProcessorHandle) statefulProcessor.init(outputMode, timeMode) - operatorProperties.put("stateVariables", JArray(driverProcessorHandle.stateVariables. - asScala.map(_.jsonValue).toList)) + driverProcessorHandle + } + + protected def closeProcessorHandle(processorHandle: StatefulProcessorHandleImpl): Unit = { + statefulProcessor.close() statefulProcessor.setHandle(null) - driverProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + } + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + validateTimeMode() if (hasInitialState) { val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index d916011245c00..ea32ccf29bfab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState /** @@ -32,6 +32,17 @@ import org.apache.spark.sql.streaming.ValueState * @param valEncoder - Spark SQL encoder for value * @tparam S - data type of object that will be stored */ +object ValueStateImpl { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + 
KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false) + } +} + class ValueStateImpl[S]( store: StateStore, stateName: String, @@ -45,8 +56,7 @@ class ValueStateImpl[S]( initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + store.createColFamilyIfAbsent(ValueStateImpl.columnFamilySchema(stateName)) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 0ed5a6f29a984..428bfa1d75776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -19,9 +19,20 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} +object ValueStateImplWithTTL { + def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = { + new ColumnFamilySchemaV1( + stateName, + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), + false) + } +} + /** * Class that provides a concrete implementation for a single value state associated with state * variables (with ttl expiration support) used in the streaming transformWithState operator. @@ -52,8 +63,7 @@ class ValueStateImplWithTTL[S]( initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + store.createColFamilyIfAbsent(ValueStateImplWithTTL.columnFamilySchema(stateName)) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 543cd74c489d0..2ba710d1b05cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -130,6 +130,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) } + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { + throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) + } + // Multiple col families are not supported with HDFSBackedStateStoreProvider. Throw an exception // if the user tries to use a non-default col family. 
private def assertUseOfDefaultColFamily(colFamilyName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 36bfb34edc412..3bc1442371322 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -25,13 +25,10 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FSDataOutputStream, Path} import org.json4s.{Formats, NoTypeHints} -import org.json4s.JsonAST.JValue import org.json4s.jackson.Serialization -import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} -import org.apache.spark.util.AccumulatorV2 /** * Metadata for a state store instance. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index e7fc9f56dd9eb..c52e69e16d581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -65,6 +65,16 @@ private[sql] class RocksDBStateStoreProvider RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey))) } + override def createColFamilyIfAbsent( + colFamilyMetadata: ColumnFamilySchemaV1): Unit = { + createColFamilyIfAbsent( + colFamilyMetadata.columnFamilyName, + colFamilyMetadata.keySchema, + colFamilyMetadata.valueSchema, + colFamilyMetadata.keyStateEncoderSpec, + colFamilyMetadata.multipleValuesPerKey) + } + override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = { verify(key != null, "Key cannot be null") val kvEncoder = keyValueEncoderMap.get(colFamilyName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 2eef3d9fc22ed..71776005ce24a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -17,14 +17,71 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.StringReader +import java.io.{OutputStream, StringReader} -import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} +import org.json4s.{DefaultFormats, JsonAST} +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.sql.execution.streaming.MetadataVersionUtil +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +sealed trait ColumnFamilySchema extends Serializable { + def jsonValue: JsonAST.JObject + + def json: String +} + +case class ColumnFamilySchemaV1( + val columnFamilyName: String, + val keySchema: StructType, + val valueSchema: 
StructType, + val keyStateEncoderSpec: KeyStateEncoderSpec, + val multipleValuesPerKey: Boolean) extends ColumnFamilySchema { + def jsonValue: JsonAST.JObject = { + ("columnFamilyName" -> JString(columnFamilyName)) ~ + ("keySchema" -> keySchema.json) ~ + ("valueSchema" -> valueSchema.json) ~ + ("keyStateEncoderSpec" -> keyStateEncoderSpec.jsonValue) ~ + ("multipleValuesPerKey" -> JBool(multipleValuesPerKey)) + } + + def json: String = { + compact(render(jsonValue)) + } +} + +object ColumnFamilySchemaV1 { + def fromJson(json: List[Map[String, Any]]): List[ColumnFamilySchema] = { + assert(json.isInstanceOf[List[_]]) + + json.map { colFamilyMap => + new ColumnFamilySchemaV1( + colFamilyMap("columnFamilyName").asInstanceOf[String], + StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]), + StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]), + KeyStateEncoderSpec.fromJson(colFamilyMap("keyStateEncoderSpec") + .asInstanceOf[Map[String, Any]]), + colFamilyMap("multipleValuesPerKey").asInstanceOf[Boolean] + ) + } + } + + def fromJValue(jValue: JValue): List[ColumnFamilySchema] = { + implicit val formats: DefaultFormats.type = DefaultFormats + val deserializedList: List[Any] = jValue.extract[List[Any]] + assert(deserializedList.isInstanceOf[List[_]], + s"Expected List but got ${deserializedList.getClass}") + val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] + // Extract each JValue to StateVariableInfo + ColumnFamilySchemaV1.fromJson(columnFamilyMetadatas) + } +} + /** * Helper classes for reading/writing state schema. */ @@ -68,6 +125,34 @@ object SchemaHelper { } } + class SchemaV3Reader( + stateCheckpointPath: Path, + hadoopConf: org.apache.hadoop.conf.Configuration) { + + private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + def read: List[ColumnFamilySchema] = { + if (!fm.exists(schemaFilePath)) { + return List.empty + } + val buf = new StringBuilder + val inputStream = fm.open(schemaFilePath) + val numKeyChunks = inputStream.readInt() + (0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF())) + val json = buf.toString() + val parsedJson = JsonMethods.parse(json) + + implicit val formats = DefaultFormats + val deserializedList: List[Any] = parsedJson.extract[List[Any]] + assert(deserializedList.isInstanceOf[List[_]], + s"Expected List but got ${deserializedList.getClass}") + val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] + // Extract each JValue to StateVariableInfo + ColumnFamilySchemaV1.fromJson(columnFamilyMetadatas) + } + } + trait SchemaWriter { val version: Int @@ -144,4 +229,55 @@ object SchemaHelper { } } } + + object SchemaV3Writer { + def getSchemaFilePath(stateCheckpointPath: Path): Path = { + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + + def serialize(out: OutputStream, schema: List[ColumnFamilySchema]): Unit = { + val json = schema.map(_.json) + out.write(compact(render(json)).getBytes("UTF-8")) + } + } + /** + * Schema writer for schema version 3. Because this writer writes out ColFamilyMetadatas + * instead of key and value schemas, it is not compatible with the SchemaWriter interface. 
+ */ + class SchemaV3Writer( + stateCheckpointPath: Path, + hadoopConf: org.apache.hadoop.conf.Configuration) { + val version: Int = 3 + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) + + // 2^16 - 1 bytes + final val MAX_UTF_CHUNK_SIZE = 65535 + def writeSchema(metadatasJson: String): Unit = { + val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + + if (fm.exists(schemaFilePath)) return + + fm.mkdirs(schemaFilePath.getParent) + val outputStream = fm.createAtomic(schemaFilePath, overwriteIfPossible = false) + // DataOutputStream.writeUTF can't write a string at once + // if the size exceeds 65535 (2^16 - 1) bytes. + // Each metadata consists of multiple chunks in schema version 3. + try { + val numMetadataChunks = (metadatasJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val metadataStringReader = new StringReader(metadatasJson) + outputStream.writeInt(numMetadataChunks) + (0 until numMetadataChunks).foreach { _ => + val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) + outputStream.writeUTF(new String(buf, 0, numRead)) + } + outputStream.close() + } catch { + case e: Throwable => + outputStream.cancel() + throw e + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 8c2170abe3116..7e0c9e4868899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -28,6 +28,10 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.{JInt, JsonAST, JString} +import org.json4s.JsonAST.JObject +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{SparkContext, SparkEnv, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, LogKeys, MDC} @@ -133,6 +137,10 @@ trait StateStore extends ReadStateStore { useMultipleValuesPerKey: Boolean = false, isInternal: Boolean = false): Unit + def createColFamilyIfAbsent( + colFamilyMetadata: ColumnFamilySchemaV1 + ): Unit + /** * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows * in the params can be reused, and must make copies of the data as needed for persistence. @@ -289,9 +297,35 @@ class InvalidUnsafeRowException(error: String) "among restart. For the first case, you can try to restart the application without " + s"checkpoint or use the legacy Spark version to process the streaming state.\n$error", null) -sealed trait KeyStateEncoderSpec +sealed trait KeyStateEncoderSpec { + def jsonValue: JsonAST.JObject + def json: String = compact(render(jsonValue)) +} -case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec +object KeyStateEncoderSpec { + def fromJson(m: Map[String, Any]): KeyStateEncoderSpec = { + // match on type + val keySchema = StructType.fromString(m("keySchema").asInstanceOf[String]) + m("keyStateEncoderType").asInstanceOf[String] match { + case "NoPrefixKeyStateEncoderSpec" => + NoPrefixKeyStateEncoderSpec(keySchema) + case "RangeKeyScanStateEncoderSpec" => + val orderingOrdinals = m("orderingOrdinals"). 
+ asInstanceOf[List[_]].map(_.asInstanceOf[Int]) + RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) + case "PrefixKeyScanStateEncoderSpec" => + val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int] + PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) + } + } +} + +case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec { + override def jsonValue: JsonAST.JObject = { + ("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec")) ~ + ("keySchema" -> JString(keySchema.json)) + } +} case class PrefixKeyScanStateEncoderSpec( keySchema: StructType, @@ -299,6 +333,12 @@ case class PrefixKeyScanStateEncoderSpec( if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString) } + + override def jsonValue: JsonAST.JObject = { + ("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~ + ("keySchema" -> JString(keySchema.json)) ~ + ("numColsPrefixKey" -> JInt(numColsPrefixKey)) + } } /** Encodes rows so that they can be range-scanned based on orderingOrdinals */ @@ -308,6 +348,12 @@ case class RangeKeyScanStateEncoderSpec( if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString) } + + override def jsonValue: JObject = { + ("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~ + ("keySchema" -> JString(keySchema.json)) ~ + ("orderingOrdinals" -> orderingOrdinals.map(JInt(_))) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index f7c6ffb8fdc47..b6c8a9970021a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -78,6 +78,21 @@ trait StatefulOperator extends SparkPlan { new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") } + + // /state//0//_metadata/schema + def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 6a476635a6dbe..fa678e4a4f78f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -40,6 +40,10 @@ class MemoryStateStore extends StateStore() { throw StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider") } + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { + throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider") + } + override def 
removeColFamilyIfExists(colFamilyName: String): Boolean = { throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index e283ba5c11f34..513bcd7540460 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.time.Duration import java.util.UUID import org.apache.spark.SparkRuntimeException @@ -35,6 +36,36 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } +class RunningCountStatefulProcessorWithTTL(ttlConfig: TTLConfig) + extends StatefulProcessor[String, String, (String, String)] + with Logging { + + @transient private var _countState: ValueStateImplWithTTL[Long] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle + .getValueState("countState", Encoders.scalaLong, ttlConfig) + .asInstanceOf[ValueStateImplWithTTL[Long]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -854,6 +885,58 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + test("transformWithState - verify that query with ttl enabled after restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { chkptDir => + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + StopStream + ) + + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithTTL(TTLConfig(Duration.ofMinutes(1))), + TimeMode.ProcessingTime(), + OutputMode.Append()) + + // verify that query with ttl enabled after restart fails + testStream(result2, OutputMode.Append())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + Execute { q => + val e = intercept[Exception] { + q.processAllAvailable() + } + assert(e.getMessage.contains("State variable with name" + + " countState already exists with different schema")) + } + ) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest {
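
For context on the core mechanism of this patch: each state variable implementation (ValueStateImpl, ListStateImpl, MapStateImpl and their TTL variants) now exposes a ColumnFamilySchemaV1, the driver collects these via StatefulProcessorHandleImpl, serializes them as a JSON array into a StateSchemaV3File, and on the first batch after a restart PopulateSchemaV3Rule reads the previous entry back and calls TransformWithStateExec.compareSchemas, which fails the query when a state variable with the same name reappears with a different schema. The sketch below is a self-contained model of that serialize/reload/compare round trip; FamilySchema, SchemaCompatibilityCheck, and Demo are illustrative stand-ins and are not classes from this patch.

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

// Illustrative stand-in for ColumnFamilySchemaV1: just enough fields to show the
// serialize -> persist -> reload -> compare flow the schema V3 file is used for.
case class FamilySchema(name: String, keySchemaJson: String, valueSchemaJson: String) {
  def jsonValue: JObject =
    ("columnFamilyName" -> name) ~
    ("keySchema" -> keySchemaJson) ~
    ("valueSchema" -> valueSchemaJson)
}

object SchemaCompatibilityCheck {
  implicit val formats: Formats = DefaultFormats

  // Serialize one batch's schemas the way the patch does: a single JSON array.
  def toJson(schemas: Seq[FamilySchema]): String =
    compact(render(JArray(schemas.map(_.jsonValue).toList)))

  // Reload the persisted array and key it by column family name.
  def fromJson(json: String): Map[String, JObject] =
    parse(json).asInstanceOf[JArray].arr.collect { case o: JObject =>
      (o \ "columnFamilyName").extract[String] -> o
    }.toMap

  // Mirrors the spirit of compareSchemas: same name with different JSON fails the restart.
  // Like the patch, this compares the rendered JSON strings of old and new schemas.
  def check(persistedJson: String, newSchemas: Seq[FamilySchema]): Unit = {
    val oldByName = fromJson(persistedJson)
    newSchemas.foreach { s =>
      oldByName.get(s.name).foreach { old =>
        if (compact(render(old)) != compact(render(s.jsonValue))) {
          throw new RuntimeException(
            s"State variable with name ${s.name} already exists with different schema.")
        }
      }
    }
  }
}

object Demo extends App {
  // Batch 0: value state without TTL.
  val v1 = Seq(FamilySchema("countState", """{"type":"key"}""", """{"type":"value"}"""))
  val persisted = SchemaCompatibilityCheck.toJson(v1)
  // Restart: same state variable now carries a TTL column in its value schema.
  val v2 = Seq(FamilySchema("countState", """{"type":"key"}""", """{"type":"value-with-ttl"}"""))
  SchemaCompatibilityCheck.check(persisted, v2) // throws: schema changed across restart
}

This is the scenario the new test at the end of the diff exercises: the first run registers countState without TTL, the restarted query registers it with TTL (VALUE_ROW_SCHEMA vs VALUE_ROW_SCHEMA_WITH_TTL), and the query fails with "State variable with name countState already exists with different schema".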