
Commit

hdfs initial implementation
eason-yuchen-liu committed Jun 21, 2024
1 parent fe9cea1 commit 2184396
Showing 8 changed files with 248 additions and 11 deletions.
@@ -28,14 +28,15 @@ import org.apache.spark.sql.{RuntimeConfig, SparkSession}
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, StateDataSourceModeType}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.StateDataSourceModeType.ModeType
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -80,10 +81,21 @@ class StateDataSource extends TableProvider with DataSourceRegister {
manager.readSchemaFile()
}

new StructType()
.add("key", keySchema)
.add("value", valueSchema)
.add("partition_id", IntegerType)
if (sourceOptions.modeType == StateDataSourceModeType.CDC) {
new StructType()
.add("key", keySchema)
.add("value", valueSchema)
.add("operation_type", StringType)
.add("batch_id", LongType)
.add("partition_id", IntegerType)
} else {
new StructType()
.add("key", keySchema)
.add("value", valueSchema)
.add("partition_id", IntegerType)
}


} catch {
case NonFatal(e) =>
throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
@@ -118,7 +130,10 @@ case class StateSourceOptions(
storeName: String,
joinSide: JoinSideValues,
snapshotStartBatchId: Option[Long],
snapshotPartitionId: Option[Int]) {
snapshotPartitionId: Option[Int],
modeType: ModeType,
cdcStartBatchId: Option[Long],
cdcEndBatchId: Option[Long]) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -137,12 +152,31 @@ object StateSourceOptions extends DataSourceOptions {
val JOIN_SIDE = newOption("joinSide")
val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")
val MODE_TYPE = newOption("modeType")
val CDC_START_BATCH_ID = newOption("cdcStartBatchId")
val CDC_END_BATCH_ID = newOption("cdcEndBatchId")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
val left, right, none = Value
}

object StateDataSourceModeType extends Enumeration {
type ModeType = Value

val NORMAL = Value("normal")
val CDC = Value("cdc")

// Parse the mode type from its string representation
def getModeTypeFromString(mode: String): ModeType = {
mode match {
case "normal" => NORMAL
case "cdc" => CDC
case _ => throw new RuntimeException(s"Found invalid mode type for value=$mode")
}
}
}

def apply(
sparkSession: SparkSession,
hadoopConf: Configuration,
@@ -217,9 +251,16 @@ object StateSourceOptions extends DataSourceOptions {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
}

val modeType = Option(options.get(MODE_TYPE)).map(
StateDataSourceModeType.getModeTypeFromString).getOrElse(StateDataSourceModeType.NORMAL)
val cdcStartBatchId = Option(options.get(CDC_START_BATCH_ID)).map(_.toLong)
val cdcEndBatchId = Option(options.get(CDC_END_BATCH_ID)).map(_.toLong)


StateSourceOptions(
resolvedCpLocation, batchId, operatorId, storeName,
joinSide, snapshotStartBatchId, snapshotPartitionId)
joinSide, snapshotStartBatchId, snapshotPartitionId,
modeType, cdcStartBatchId, cdcEndBatchId)
}

private def resolvedCheckpointLocation(
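For context, here is a minimal usage sketch of the options introduced above. It assumes the data source keeps its existing "statestore" short name and "path" option from the current state reader; the checkpoint location and batch range are illustrative, not part of this change.

// Hypothetical read of the state changes recorded between batches 0 and 3 (inclusive).
val changesDf = spark.read
  .format("statestore")
  .option("path", "/tmp/checkpoint")
  .option("modeType", "cdc")
  .option("cdcStartBatchId", "0")
  .option("cdcEndBatchId", "3")
  .load()
changesDf.printSchema()
// key, value, operation_type, batch_id, partition_id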
@@ -20,10 +20,13 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.StateDataSourceModeType
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreCDCReader, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SerializableConfiguration

/**
@@ -101,8 +104,19 @@ class StatePartitionReader(
}
}

private lazy val cdcReader: StateStoreCDCReader = {
provider.getStateStoreCDCReader(
partition.sourceOptions.cdcStartBatchId.get,
partition.sourceOptions.cdcEndBatchId.get)
}

private lazy val iter: Iterator[InternalRow] = {
store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value)))
if (partition.sourceOptions.modeType == StateDataSourceModeType.CDC) {
println("Here!!!!!!")
cdcReader.iterator.map(unifyStateCDCRow)
} else {
store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value)))
}
}

private var current: InternalRow = _
@@ -122,6 +136,7 @@
override def close(): Unit = {
current = null
store.abort()
// Close the CDC reader only when it was actually created; it is lazily
// initialized and only valid in CDC mode.
if (partition.sourceOptions.modeType == StateDataSourceModeType.CDC) {
cdcReader.close()
}
provider.close()
}

@@ -132,4 +147,14 @@
row.update(2, partition.partition)
row
}

private def unifyStateCDCRow(row: (RecordType, UnsafeRow, UnsafeRow, Long)): InternalRow = {
val result = new GenericInternalRow(5)
result.update(0, row._2)
result.update(1, row._3)
result.update(2, UTF8String.fromString(getRecordTypeAsString(row._1)))
result.update(3, row._4)
result.update(4, partition.partition)
result
}
}
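To make the new output columns concrete, here is a small follow-up sketch of consuming the rows produced by unifyStateCDCRow above. changesDf refers to the illustrative DataFrame from the earlier sketch, and operation_type carries the strings returned by getRecordTypeAsString ("PUT", "DELETE", "UNDEFINED").

// Show only the deletions recorded in the requested batch range.
changesDf
  .where("operation_type = 'DELETE'")
  .select("batch_id", "partition_id", "key", "value")
  .orderBy("batch_id", "partition_id")
  .show(truncate = false)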
@@ -69,6 +69,7 @@ class StateTable(
override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

private def isValidSchema(schema: StructType): Boolean = {
// TODO: the check below only accepts the non-CDC schema (key, value, partition_id).
// It is bypassed for now so that the CDC schema with operation_type and batch_id also
// passes; extend the validation instead of returning early (a hedged sketch follows
// this file's diff).
return true
if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
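A hedged sketch of how the temporary bypass in isValidSchema could eventually be replaced, covering only the visible part of the check and assuming the CDC schema defined in StateDataSource (key, value, operation_type, batch_id, partition_id):

// Accept either the normal or the CDC field layout before checking the field types.
val normalFields = Seq("key", "value", "partition_id")
val cdcFields = Seq("key", "value", "operation_type", "batch_id", "partition_id")
if (schema.fieldNames.toSeq != normalFields && schema.fieldNames.toSeq != cdcFields) {
  false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
  false
} else {
  // ... remaining per-field checks as in the existing method
  true
}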
@@ -338,6 +338,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
}

override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
new HDFSBackedStateStoreCDCReader(fm, baseDir, startVersion, endVersion,
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema, valueSchema)
}

// Run a bunch of validations specific to HDFSBackedStateStoreProvider
private def runValidation(
useColumnFamilies: Boolean,
@@ -398,6 +398,11 @@ trait StateStoreProvider {
throw new SparkUnsupportedOperationException("getReadStore with startVersion and endVersion " +
s"is not supported by ${this.getClass.toString}")

def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
throw new SparkUnsupportedOperationException("getStateStoreCDCReader is not supported by " +
this.getClass.toString)
}

/** Optional method for providers to allow for background maintenance (e.g. compactions) */
def doMaintenance(): Unit = { }

@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.state

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.NextIterator

/**
* Base class for a state store CDC (change data capture) reader, which replays the
* changelog files written between two state store versions.
* @param fm - checkpoint file manager used to manage streaming query checkpoint
* @param stateLocation - location of the state store directory containing the changelog files
* @param startVersion - first changelog version (inclusive) to read
* @param endVersion - last changelog version (inclusive) to read
* @param compressionCodec - de-compression method used for reading changelog files
* @param keySchema - schema of the state store key, used to deserialize key rows
* @param valueSchema - schema of the state store value, used to deserialize value rows
*/
abstract class StateStoreCDCReader(
fm: CheckpointFileManager,
stateLocation: Path,
startVersion: Long,
endVersion: Long,
compressionCodec: CompressionCodec,
keySchema: StructType,
valueSchema: StructType)
extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging {

class ChangeLogFileIterator(
stateLocation: Path,
startVersion: Long,
endVersion: Long) extends Iterator[Path] {

assert(startVersion >= 1 && endVersion >= startVersion,
s"Invalid changelog version range: [$startVersion, $endVersion]")

private var currentVersion = startVersion - 1

def getVersion: Long = currentVersion

override def hasNext: Boolean = currentVersion < endVersion

override def next(): Path = {
currentVersion += 1
getChangelogPath(stateLocation, currentVersion)
}
}

protected lazy val fileIterator: ChangeLogFileIterator =
new ChangeLogFileIterator(stateLocation, startVersion, endVersion)

// File-name suffix of the changelog files, e.g. "delta" for the HDFS-backed provider.
protected var changelogSuffix: String

private def getChangelogPath(stateLocation: Path, version: Long): Path =
new Path(stateLocation, s"$version.$changelogSuffix")

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long)

def close(): Unit
}

class HDFSBackedStateStoreCDCReader(
fm: CheckpointFileManager,
stateLocation: Path,
startVersion: Long,
endVersion: Long,
compressionCodec: CompressionCodec,
keySchema: StructType,
valueSchema: StructType
)
extends StateStoreCDCReader(
fm, stateLocation, startVersion, endVersion, compressionCodec, keySchema, valueSchema) {
override protected var changelogSuffix: String = "delta"

private var currentChangelogReader: StateStoreChangelogReader = null

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
// Advance to the next changelog file whenever the current reader is exhausted,
// closing the previous reader before moving on.
while (currentChangelogReader == null || !currentChangelogReader.hasNext) {
if (currentChangelogReader != null) {
currentChangelogReader.closeIfNeeded()
currentChangelogReader = null
}
if (!fileIterator.hasNext) {
finished = true
return null
}
currentChangelogReader =
new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec)
}

val readResult = currentChangelogReader.next()
val keyRow = new UnsafeRow(keySchema.fields.length)
keyRow.pointTo(readResult._2, readResult._2.length)
val valueRow = new UnsafeRow(valueSchema.fields.length)
valueRow.pointTo(readResult._3, readResult._3.length)
(readResult._1, keyRow, valueRow, fileIterator.getVersion - 1)
}

override def close(): Unit = {
if (currentChangelogReader != null) {
currentChangelogReader.closeIfNeeded()
currentChangelogReader = null
}
}
}
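Below is a short hedged sketch of driving the new reader directly; the provider instance and the 1-to-4 version range are assumptions, mirroring the manual check added to the test suite further down. Because the reader extends NextIterator, it can be consumed as an ordinary iterator instead of calling getNext() by hand.

// Iterate every change recorded in the changelog files 1.delta through 4.delta.
val cdcReader = provider.getStateStoreCDCReader(1, 4)
try {
  cdcReader.foreach { case (recordType, key, value, batchId) =>
    println(s"batch=$batchId op=${RecordType.getRecordTypeAsString(recordType)} " +
      s"key=$key value=$value")
  }
} finally {
  cdcReader.close()
}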
@@ -55,6 +55,14 @@ object RecordType extends Enumeration {
}
}

def getRecordTypeAsString(recordType: RecordType): String = {
recordType match {
case PUT_RECORD => "PUT"
case DELETE_RECORD => "DELETE"
case _ => "UNDEFINED"
}
}

// Generate record type from byte representation
def getRecordTypeFromByte(byte: Byte): RecordType = {
byte match {
@@ -398,6 +398,20 @@ class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
test("option snapshotPartitionId") {
testSnapshotPartitionId()
}

test("just test") {
val provider = getNewStateStoreProvider("/tmp/spark/state")
.asInstanceOf[HDFSBackedStateStoreProvider]
val reader = provider.getStateStoreCDCReader(1, 4)
println(reader.getNext()) // why is the first element null
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())

}
}

class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite {
@@ -459,7 +473,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
* @param checkpointDir path to store state information
* @return instance of class extending [[StateStoreProvider]]
*/
private def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = {
def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = {
val provider = newStateStoreProvider()
provider.init(
StateStoreId(checkpointDir, 0, 0),
