Introducing StateSchemaV3 for the TransformWithState operator with HDFSMetadataLog #11

Closed
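This change moves the transformWithState operator's state schema and operator metadata into HDFSMetadataLog-backed files that are written once per run, on the first batch, by a new planner rule in IncrementalExecution. The sketch below condenses that flow using only the classes and methods added in this diff; it is illustrative (assumed to sit in the same package as the planner), not the exact rule body. The per-batch metadata writes previously done in MicroBatchExecution are removed accordingly (see that file's diff further down).

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.errors.QueryExecutionErrors

// Condensed first-batch flow; the real logic lives in writeSchemaAndMetadataFiles and
// PopulateSchemaV3Rule below. `tws` is the planned TransformWithStateExec node.
def commitSchemaAndMetadata(
    tws: TransformWithStateExec,
    hadoopConf: Configuration,
    currentBatchId: Long): Unit = {
  val schemaFile = new StateSchemaV3File(hadoopConf, tws.stateSchemaFilePath().toString)
  val metadataLog = new OperatorStateMetadataLog(hadoopConf, tws.metadataFilePath().toString)

  // On restart, validate the operator's current schema against the last committed one.
  schemaFile.getLatest().foreach { case (_, oldSchema) =>
    tws.compareSchemas(oldSchema, tws.getSchema())
  }

  // Remove leftovers from a failed attempt of this batch, then commit both logs.
  metadataLog.purgeAfter(currentBatchId - 1)
  if (!schemaFile.add(currentBatchId, tws.getSchema()) ||
      !metadataLog.add(currentBatchId, tws.operatorStateMetadata())) {
    throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
  }
}
```
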
IncrementalExecution.scala
@@ -21,6 +21,7 @@ import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

import org.apache.hadoop.fs.Path
import org.json4s.JsonAST.JValue

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR}
@@ -37,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -187,6 +188,52 @@ class IncrementalExecution(
}
}

def writeSchemaAndMetadataFiles(
stateSchemaV3File: StateSchemaV3File,
operatorStateMetadataLog: OperatorStateMetadataLog,
stateSchema: JValue,
operatorStateMetadata: OperatorStateMetadata): Unit = {
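// Drop any operator metadata entries later than the previous batch (e.g. leftovers from a
// failed attempt of this batch) before re-committing this batch's schema and metadata.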
operatorStateMetadataLog.purgeAfter(currentBatchId - 1)
if (!stateSchemaV3File.add(currentBatchId, stateSchema)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
}
if (!operatorStateMetadataLog.add(currentBatchId, operatorStateMetadata)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
}
}

object PopulateSchemaV3Rule extends SparkPlanPartialRule with Logging {
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case tws: TransformWithStateExec if isFirstBatch =>
val stateSchemaV3File = new StateSchemaV3File(
hadoopConf, tws.stateSchemaFilePath().toString)
val operatorStateMetadataLog = new OperatorStateMetadataLog(
hadoopConf,
tws.metadataFilePath().toString
)
stateSchemaV3File.getLatest() match {
case Some((_, oldSchema)) =>
val newSchema = tws.getSchema()
tws.compareSchemas(oldSchema, newSchema)
writeSchemaAndMetadataFiles(
stateSchemaV3File = stateSchemaV3File,
operatorStateMetadataLog = operatorStateMetadataLog,
stateSchema = newSchema,
operatorStateMetadata = tws.operatorStateMetadata()
)
tws.copy(columnFamilyJValue = Some(oldSchema))
case None =>
writeSchemaAndMetadataFiles(
stateSchemaV3File = stateSchemaV3File,
operatorStateMetadataLog = operatorStateMetadataLog,
stateSchema = tws.getSchema(),
operatorStateMetadata = tws.operatorStateMetadata()
)
tws
}
}
}

object StateOpIdRule extends SparkPlanPartialRule {
override val rule: PartialFunction[SparkPlan, SparkPlan] = {
case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
@@ -455,15 +502,16 @@

override def apply(plan: SparkPlan): SparkPlan = {
val planWithStateOpId = plan transform composedRule
val planWithSchema = planWithStateOpId transform PopulateSchemaV3Rule.rule
// Need to check before writing the metadata because we need to detect added operators.
// Only check when streaming is restarting and this is the first batch.
if (isFirstBatch && currentBatchId != 0) {
checkOperatorValidWithMetadata(planWithStateOpId)
checkOperatorValidWithMetadata(planWithSchema)
}
// The rule doesn't change the plan, but it causes the side effect that metadata is written
// to the checkpoint directory of the stateful operator.
simulateWatermarkPropagation(planWithStateOpId)
planWithStateOpId transform WatermarkPropagationRule.rule
simulateWatermarkPropagation(planWithSchema)
planWithSchema transform WatermarkPropagationRule.rule
}
}

ListStateImpl.scala
@@ -20,9 +20,19 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.ListState

object ListStateImpl {
def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
KEY_ROW_SCHEMA,
VALUE_ROW_SCHEMA,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
true)
}
}
/**
* Provides concrete implementation for list of values associated with a state variable
* used in the streaming transformWithState operator.
@@ -44,8 +54,7 @@ class ListStateImpl[S](

private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)

store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
store.createColFamilyIfAbsent(ListStateImpl.columnFamilySchema(stateName))

/** Whether state exists or not. */
override def exists(): Boolean = {
ListStateImplWithTTL.scala
@@ -19,10 +19,20 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}
import org.apache.spark.util.NextIterator

object ListStateImplWithTTL {
def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
KEY_ROW_SCHEMA,
VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
true)
}
}
/**
* Class that provides a concrete implementation for a list state associated with state
* variables (with ttl expiration support) used in the streaming transformWithState operator.
@@ -55,8 +65,7 @@ class ListStateImplWithTTL[S](
initialize()

private def initialize(): Unit = {
store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
store.createColFamilyIfAbsent(ListStateImplWithTTL.columnFamilySchema(stateName))
}

/** Whether state exists or not. */
MapStateImpl.scala
@@ -19,9 +19,19 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.streaming.MapState
import org.apache.spark.sql.types.{BinaryType, StructType}

object MapStateImpl {
def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
COMPOSITE_KEY_ROW_SCHEMA,
VALUE_ROW_SCHEMA,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
}
}

class MapStateImpl[K, V](
store: StateStore,
@@ -30,18 +40,11 @@ class MapStateImpl[K, V](
userKeyEnc: Encoder[K],
valEncoder: Encoder[V]) extends MapState[K, V] with Logging {

// Pack grouping key and user key together as a prefixed composite key
private val schemaForCompositeKeyRow: StructType =
new StructType()
.add("key", BinaryType)
.add("userKey", BinaryType)
private val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
private val keySerializer = keyExprEnc.createSerializer()
private val stateTypesEncoder = new CompositeKeyStateEncoder(
keySerializer, userKeyEnc, valEncoder, schemaForCompositeKeyRow, stateName)
keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName)

store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow,
PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1))
store.createColFamilyIfAbsent(MapStateImpl.columnFamilySchema(stateName))

/** Whether state exists or not. */
override def exists(): Boolean = {
MapStateImplWithTTL.scala
@@ -20,10 +20,19 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{MapState, TTLConfig}
import org.apache.spark.util.NextIterator

object MapStateImplWithTTL {
def columnFamilySchema(stateName: String): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
stateName,
COMPOSITE_KEY_ROW_SCHEMA,
VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
}
}
/**
* Class that provides a concrete implementation for map state associated with state
* variables (with ttl expiration support) used in the streaming transformWithState operator.
@@ -58,8 +67,7 @@ class MapStateImplWithTTL[K, V](
initialize()

private def initialize(): Unit = {
store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
store.createColFamilyIfAbsent(MapStateImplWithTTL.columnFamilySchema(stateName))
}

/** Whether state exists or not. */
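Across the four state implementations above, the column-family layout now lives in one companion helper per state type, and the store registration takes that schema object directly. A minimal sketch of how the list and map layouts differ, assuming an already-initialized `store: StateStore` and hypothetical state names:

```scala
// List state stores multiple values per grouping key, so useMultipleValuesPerKey = true.
val listCf = ListStateImpl.columnFamilySchema("eventsList")

// Map state packs the grouping key and user key into a composite key and scans by a
// one-column prefix, so it uses PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1)
// with a single value per composite key.
val mapCf = MapStateImpl.columnFamilySchema("countsPerUserKey")

// The new createColFamilyIfAbsent overload takes the ColumnFamilySchemaV1 directly.
store.createColFamilyIfAbsent(listCf)
store.createColFamilyIfAbsent(mapCf)
```
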
MicroBatchExecution.scala
@@ -903,23 +903,6 @@ class MicroBatchExecution(
if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
val shouldWriteMetadatas = execCtx.previousContext match {
case Some(prevCtx)
if prevCtx.executionPlan.runId == execCtx.executionPlan.runId =>
false
case _ => true
}
if (shouldWriteMetadatas) {
execCtx.executionPlan.executedPlan.collect {
case s: StateStoreWriter =>
val metadata = s.operatorStateMetadata()
val id = metadata.operatorInfo.operatorId
val metadataFile = operatorStateMetadataLogs(id)
if (!metadataFile.add(execCtx.batchId, metadata)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
}
}
committedOffsets ++= execCtx.endOffsets
}
OperatorStateMetadataLog.scala
@@ -66,4 +66,32 @@ class OperatorStateMetadataLog(
case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader)
}
}


/**
* Store the metadata for the specified batchId and return `true` if successful. If the batchId's
* metadata has already been stored, this method will return `false`.
*/
override def add(batchId: Long, metadata: OperatorStateMetadata): Boolean = {
require(metadata != null, "'null' metadata cannot be written to a metadata log")
val batchMetadataFile = batchIdToPath(batchId)
if (fileManager.exists(batchMetadataFile)) {
fileManager.delete(batchMetadataFile)
}
val res = addNewBatchByStream(batchId) { output => serialize(metadata, output) }
if (metadataCacheEnabled && res) batchCache.put(batchId, metadata)
res
}

override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): Boolean = {
val batchMetadataFile = batchIdToPath(batchId)

// When the metadata cache already holds this batch, it was stored before; skip the write.
if (metadataCacheEnabled && batchCache.containsKey(batchId)) {
false
} else {
write(batchMetadataFile, fn)
true
}
}

}
StateSchemaV3File.scala (new file)
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import java.io.{InputStream, OutputStream, StringReader}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream}
import org.json4s.JValue
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

class StateSchemaV3File(
hadoopConf: Configuration,
path: String,
metadataCacheEnabled: Boolean = false)
extends HDFSMetadataLog[JValue](hadoopConf, path, metadataCacheEnabled) {

final val MAX_UTF_CHUNK_SIZE = 65535
def this(sparkSession: SparkSession, path: String) = {
this(
sparkSession.sessionState.newHadoopConf(),
path,
metadataCacheEnabled = sparkSession.sessionState.conf.getConf(
SQLConf.STREAMING_METADATA_CACHE_ENABLED)
)
}

override protected def serialize(schema: JValue, out: OutputStream): Unit = {
val json = compact(render(schema))
val buf = new Array[Char](MAX_UTF_CHUNK_SIZE)

val outputStream = out.asInstanceOf[FSDataOutputStream]
// DataOutputStream.writeUTF can't write a string in a single call
// if its size exceeds 65535 (2^16 - 1) bytes,
// so each schema entry is written as multiple chunks in schema version 3.
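// For example, a 100,000-character JSON is written as (100000 - 1) / 65535 + 1 = 2 chunks.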
try {
val numMetadataChunks = (json.length - 1) / MAX_UTF_CHUNK_SIZE + 1
val metadataStringReader = new StringReader(json)
outputStream.writeInt(numMetadataChunks)
(0 until numMetadataChunks).foreach { _ =>
val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE)
outputStream.writeUTF(new String(buf, 0, numRead))
}
outputStream.close()
} catch {
case e: Throwable =>
throw e
}
}

override protected def deserialize(in: InputStream): JValue = {
val buf = new StringBuilder
val inputStream = in.asInstanceOf[FSDataInputStream]
val numMetadataChunks = inputStream.readInt()
(0 until numMetadataChunks).foreach(_ => buf.append(inputStream.readUTF()))
val json = buf.toString()
JsonMethods.parse(json)
}

override def add(batchId: Long, metadata: JValue): Boolean = {
require(metadata != null, "'null' metadata cannot be written to a metadata log")
val batchMetadataFile = batchIdToPath(batchId)
if (fileManager.exists(batchMetadataFile)) {
fileManager.delete(batchMetadataFile)
}
val res = addNewBatchByStream(batchId) { output => serialize(metadata, output) }
if (metadataCacheEnabled && res) batchCache.put(batchId, metadata)
res
}

override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): Boolean = {
val batchMetadataFile = batchIdToPath(batchId)

if (metadataCacheEnabled && batchCache.containsKey(batchId)) {
false
} else {
write(batchMetadataFile, fn)
true
}
}
}
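
For reference, a hedged round-trip sketch of the new schema file; it assumes access from the same package, and the checkpoint path and JSON content are purely illustrative:

```scala
import org.apache.hadoop.conf.Configuration
import org.json4s.jackson.JsonMethods

val schemaFile = new StateSchemaV3File(new Configuration(), "/tmp/checkpoint/state/0/_stateSchema")
val schema = JsonMethods.parse("""{"columnFamilies": [{"name": "listState"}]}""")

// add() rewrites any half-written file for this batch; it only returns false when the
// (enabled) metadata cache already holds an entry for the batch.
assert(schemaFile.add(0L, schema))

// On restart, the latest committed schema comes back as a (batchId, JValue) pair.
schemaFile.getLatest().foreach { case (batchId, committed) =>
  assert(batchId == 0L && committed == schema)
}
```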