stage
eason-yuchen-liu committed Jun 24, 2024
1 parent cf84d50 commit 6922595
Showing 5 changed files with 200 additions and 17 deletions.
@@ -301,6 +301,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
new HDFSBackedReadStateStore(endVersion, newMap)
}

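  // Creates a CDC reader over this store's delta files between `startVersion` and `endVersion`.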
override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
new HDFSBackedStateStoreCDCReader(fm, baseDir, startVersion, endVersion,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema, valueSchema)
}

private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
try {
if (version < 0) {
@@ -338,12 +344,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
}

override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
new HDFSBackedStateStoreCDCReader(fm, baseDir, startVersion, endVersion,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema, valueSchema)
}

// Run bunch of validations specific to HDFSBackedStateStoreProvider
private def runValidation(
useColumnFamilies: Boolean,
@@ -26,8 +26,10 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -354,6 +356,19 @@ private[sql] class RocksDBStateStoreProvider
}
}

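  // Creates a CDC reader over this store's changelog files between `startVersion` and
  // `endVersion`.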
override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
val statePath = stateStoreId.storeCheckpointLocation()
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
new RocksDBStateStoreCDCReader(
CheckpointFileManager.create(statePath, hadoopConf),
statePath,
startVersion,
endVersion,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema,
valueSchema)
}

override def doMaintenance(): Unit = {
try {
rocksDB.doMaintenance()
@@ -34,13 +34,10 @@ import org.apache.spark.util.NextIterator
*/
abstract class StateStoreCDCReader(
fm: CheckpointFileManager,
// fileToRead: Path,
stateLocation: Path,
startVersion: Long,
endVersion: Long,
compressionCodec: CompressionCodec,
keySchema: StructType,
valueSchema: StructType)
compressionCodec: CompressionCodec)
extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging {

class ChangeLogFileIterator(
@@ -89,7 +86,7 @@ class HDFSBackedStateStoreCDCReader(
valueSchema: StructType
)
extends StateStoreCDCReader(
fm, stateLocation, startVersion, endVersion, compressionCodec, keySchema, valueSchema) {
fm, stateLocation, startVersion, endVersion, compressionCodec) {
override protected var changelogSuffix: String = "delta"

private var currentChangelogReader: StateStoreChangelogReader = null
@@ -125,3 +122,51 @@ }
}
}
}

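/**
 * [[StateStoreCDCReader]] implementation for [[RocksDBStateStoreProvider]]. It iterates the
 * changelog files written for versions `startVersion` to `endVersion` and returns each change
 * record together with its version.
 */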
class RocksDBStateStoreCDCReader(
fm: CheckpointFileManager,
stateLocation: Path,
startVersion: Long,
endVersion: Long,
compressionCodec: CompressionCodec,
keySchema: StructType,
valueSchema: StructType
)
extends StateStoreCDCReader(
fm, stateLocation, startVersion, endVersion, compressionCodec) {
override protected var changelogSuffix: String = "changelog"

private var currentChangelogReader: StateStoreChangelogReader = null

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
while (currentChangelogReader == null || !currentChangelogReader.hasNext) {
if (currentChangelogReader != null) {
currentChangelogReader.close()
currentChangelogReader = null
}
if (!fileIterator.hasNext) {
finished = true
return null
}
currentChangelogReader =
new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec)
}

val readResult = currentChangelogReader.next()
val keyRow = new UnsafeRow(keySchema.fields.length)
keyRow.pointTo(readResult._2, readResult._2.length)
val valueRow = new UnsafeRow(valueSchema.fields.length)
    // If the value size in an existing file is not a multiple of 8, floor it to a multiple
    // of 8. This is a workaround for the following issue: prior to Spark 2.3,
    // `RowBasedKeyValueBatch` mistakenly appended 4 bytes to the value row, and those
    // extra bytes were persisted into the checkpoint data.
valueRow.pointTo(readResult._3, (readResult._3.length / 8) * 8)
(readResult._1, keyRow, valueRow, fileIterator.getVersion - 1)
}

override def close(): Unit = {
if (currentChangelogReader != null) {
currentChangelogReader.close()
}
}
}
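For orientation, a minimal, hypothetical sketch of how the CDC readers added in this file could be consumed. The `provider` value and the version bounds are assumptions, not part of this change; the tuple shape follows the NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] signature above.

  // Hypothetical consumer of a StateStoreCDCReader; `provider` is assumed to be an
  // already-initialized state store provider, and the version bounds are illustrative.
  val cdcReader = provider.getStateStoreCDCReader(startVersion = 1L, endVersion = 3L)
  try {
    cdcReader.foreach { case (recordType, keyRow, valueRow, version) =>
      // Each element is (change type, key as UnsafeRow, value as UnsafeRow, version).
      println(s"version=$version type=$recordType key=$keyRow value=$valueRow")
    }
  } finally {
    // closeIfNeeded() comes from NextIterator and closes the underlying changelog reader once.
    cdcReader.closeIfNeeded()
  }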
@@ -444,8 +444,8 @@ class SymmetricHashJoinStateManager(
private val keySchema = StructType(
joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
private val keyAttributes = toAttributes(keySchema)
private val keyToNumValues = new KeyToNumValuesStore()
private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)
private lazy val keyToNumValues = new KeyToNumValuesStore()
private lazy val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)

// Clean up any state store resources if necessary at the end of the task
Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } }
@@ -476,6 +476,16 @@

def metrics: StateStoreMetrics = stateStore.metrics

  private def initializeStateStoreProvider(
      keySchema: StructType,
      valueSchema: StructType): Unit = {
val storeProviderId = StateStoreProviderId(
stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
stateStoreProvider = StateStoreProvider.createAndInit(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
useColumnFamilies = false, storeConf, hadoopConf,
useMultipleValuesPerKey = false)
}

/** Get the StateStore with the given schema */
protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
val storeProviderId = StateStoreProviderId(
@@ -488,10 +498,7 @@
stateInfo.get.storeVersion, useColumnFamilies = false, storeConf, hadoopConf)
} else {
// This class will manage the state store provider by itself.
stateStoreProvider = StateStoreProvider.createAndInit(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
useColumnFamilies = false, storeConf, hadoopConf,
useMultipleValuesPerKey = false)
initializeStateStoreProvider(keySchema, valueSchema)
if (snapshotStartVersion.isDefined) {
stateStoreProvider.getStore(snapshotStartVersion.get, stateInfo.get.storeVersion)
} else {
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import org.scalatest.Assertions

import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf



class HDFSBackedStateDataSourceReadCDCSuite extends StateDataSourceCDCReadSuite {
override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider =
new HDFSBackedStateStoreProvider

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
newStateStoreProvider().getClass.getName)
    // Make sure we have a snapshot for every two delta files.
    // The HDFS maintenance task does not count the latest delta file, which has the same
    // version as the snapshot.
spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 1)
}
}

class RocksDBStateDataSourceCDCReadSuite extends StateDataSourceCDCReadSuite {
override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
new RocksDBStateStoreProvider

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
newStateStoreProvider().getClass.getName)
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
"false")
}
}

class RocksDBWithChangelogCheckpointStateDataSourceCDCReaderSuite extends
StateDataSourceCDCReadSuite {
override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
new RocksDBStateStoreProvider

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
newStateStoreProvider().getClass.getName)
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
"true")
    // Make sure we have a snapshot for every other checkpoint.
    // The RocksDB maintenance task counts the latest checkpoint, so this needs to be set to 2.
spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 2)
}
}

abstract class StateDataSourceCDCReadSuite extends StateDataSourceTestBase with Assertions {
protected def newStateStoreProvider(): StateStoreProvider

test("cdc read limit state") {
    withTempDir { tempDir =>
      import testImplicits._
      spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, 500)
      val inputData = MemoryStream[Int]
      val df = inputData.toDF().limit(10)
      testStream(df)(
        StartStream(checkpointLocation = tempDir.getAbsolutePath),
        AddData(inputData, 1, 2, 3, 4),
        CheckLastBatch(1, 2, 3, 4),
        AddData(inputData, 5, 6, 7, 8),
        CheckLastBatch(5, 6, 7, 8),
        AddData(inputData, 9, 10, 11, 12),
        CheckLastBatch(9, 10)
      )

      val stateDf = spark.read.format("statestore")
        .option(StateSourceOptions.MODE_TYPE, "cdc")
        .option(StateSourceOptions.CDC_START_BATCH_ID, 0)
        .option(StateSourceOptions.CDC_END_BATCH_ID, 2)
        .load(tempDir.getAbsolutePath)
      stateDf.show()

      // The expected CDC output is not asserted yet; for now only verify that the
      // reader returns change records for the recorded limit state.
      assert(!stateDf.isEmpty)
    }
}

test("cdc read aggregate state") {

}

test("cdc read deduplication state") {

}

test("cdc read stream-stream join state") {

}
}
