[SPARK-47746] Implement ordinal-based range encoding in the RocksDBStateEncoder

### What changes were proposed in this pull request?

The RocksDBStateEncoder now implements range projection by reading a list of ordering ordinals and using it to project the corresponding columns, encoded in big-endian, to the front of the `Array[Byte]` encoded rows returned by the encoder.
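
As a rough illustration (not code from this PR), the reason the projected columns are written in big-endian is that RocksDB orders keys by byte-wise comparison, so a big-endian layout makes that comparison agree with numeric ordering:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Hypothetical helper: encode a Long in big-endian. For non-negative values,
// RocksDB's lexicographic byte comparison of these arrays then matches the
// numeric ordering. The actual encoder also prepends a marker byte for
// null/sign handling, which is omitted here.
def encodeLongBigEndian(value: Long): Array[Byte] = {
  val bbuf = ByteBuffer.allocate(8)
  bbuf.order(ByteOrder.BIG_ENDIAN)
  bbuf.putLong(value)
  bbuf.array()
}

// 10L sorts before 256L byte-wise only because the most significant byte comes
// first; a little-endian layout would compare the least significant byte first
// and break the ordering.
```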

### Why are the changes needed?

StateV2 implementations (and other state-related operators) project certain columns to the front of `UnsafeRow`s, and then rely on the RocksDBStateEncoder to range-encode those columns. We can avoid the initial projection by just passing the RocksDBStateEncoder the ordinals to encode at the front. This should avoid any GC or codegen overheads associated with projection.
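
A minimal sketch of what this looks like for a caller, with a made-up schema and column names (`RangeKeyScanStateEncoderSpec` is a Spark-internal API, shown here only to illustrate the new parameter):

```scala
import org.apache.spark.sql.execution.streaming.state.RangeKeyScanStateEncoderSpec
import org.apache.spark.sql.types.{LongType, StringType, StructType}

// Illustrative schema; the column names are made up.
val keySchema = new StructType()
  .add("key", StringType)
  .add("expirationMs", LongType)

// Before this change, callers passed a count of leading columns, e.g.
// RangeKeyScanStateEncoderSpec(keySchema, numOrderingCols = 1), which forced
// them to first project the ordering columns to the front of the row.
// Now they pass the ordinals directly: column 1 (expirationMs) is range-encoded
// to the front of the encoded bytes without an extra UnsafeProjection.
val spec = RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals = Seq(1))
```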

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UTs. All existing UTs should pass.

### Was this patch authored or co-authored using generative AI tooling?

Yes

Closes #45905 from neilramaswamy/spark-47746.

Authored-by: Neil Ramaswamy <neil.ramaswamy@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
neilramaswamy authored and HeartSaVioR committed Apr 8, 2024
1 parent 29d077f commit 60806c6
Showing 8 changed files with 228 additions and 60 deletions.
2 changes: 1 addition & 1 deletion common/utils/src/main/resources/error/error-classes.json
@@ -3630,7 +3630,7 @@
},
"STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN" : {
"message" : [
"Incorrect number of ordering columns=<numOrderingCols> for range scan encoder. Ordering columns cannot be zero or greater than num of schema columns."
"Incorrect number of ordering ordinals=<numOrderingCols> for range scan encoder. The number of ordering ordinals cannot be zero or greater than number of schema columns."
],
"sqlState" : "42802"
},
2 changes: 1 addition & 1 deletion docs/sql-error-conditions.md
@@ -2236,7 +2236,7 @@ Please only use the StatefulProcessor within the transformWithState operator.

[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

Incorrect number of ordering columns=`<numOrderingCols>` for range scan encoder. Ordering columns cannot be zero or greater than num of schema columns.
Incorrect number of ordering ordinals=`<numOrderingCols>` for range scan encoder. The number of ordering ordinals cannot be zero or greater than number of schema columns.

### STATE_STORE_INCORRECT_NUM_PREFIX_COLS_FOR_PREFIX_SCAN

@@ -93,7 +93,7 @@ abstract class SingleKeyTTLStateImpl(
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))

store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA,
RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, 1), isInternal = true)
RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, Seq(0)), isInternal = true)

def upsertTTLForStateKey(
expirationMs: Long,
@@ -91,7 +91,7 @@ class TimerStateImpl(

private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
useMultipleValuesPerKey = false, isInternal = true)

private def getGroupingKey(cfName: String): Any = {
@@ -52,8 +52,8 @@ object RocksDBStateEncoder {
case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)

case RangeKeyScanStateEncoderSpec(keySchema, numOrderingCols) =>
new RangeKeyScanStateEncoder(keySchema, numOrderingCols)
case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
new RangeKeyScanStateEncoder(keySchema, orderingOrdinals)

case _ =>
throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
@@ -204,18 +204,18 @@ class PrefixKeyScanStateEncoder(
/**
* RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size fields
*
* To encode a row for range scan, we first project the first numOrderingCols needed
* for the range scan into an UnsafeRow; we then rewrite that UnsafeRow's fields in BIG_ENDIAN
* To encode a row for range scan, we first project the orderingOrdinals from the original
* UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's fields in BIG_ENDIAN
* to allow for scanning keys in sorted order using the byte-wise comparison method that
* RocksDB uses.
*
* Then, for the rest of the fields, we project those into another UnsafeRow.
* We then effectively join these two UnsafeRows together, and finally take those bytes
* to get the resulting row.
*
* We cannot support variable sized fields given the UnsafeRow format which stores variable
* sized fields as offset and length pointers to the actual values, thereby changing the required
* ordering.
* We cannot support variable sized fields in the range scan because the UnsafeRow format
* stores variable sized fields as offset and length pointers to the actual values,
* thereby changing the required ordering.
*
* Note that we also support "null" values being passed for these fixed size fields. We prepend
* a single byte to indicate whether the column value is null or not. We cannot change the
@@ -229,16 +229,19 @@
* here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
*
* @param keySchema - schema of the key to be encoded
* @param numOrderingCols - number of columns to be used for range scan
* @param orderingOrdinals - the ordinals for which the range scan is constructed
*/
class RangeKeyScanStateEncoder(
keySchema: StructType,
numOrderingCols: Int) extends RocksDBKeyStateEncoder {
orderingOrdinals: Seq[Int]) extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

private val rangeScanKeyFieldsWithIdx: Seq[(StructField, Int)] = {
keySchema.zipWithIndex.take(numOrderingCols)
private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
orderingOrdinals.map { ordinal =>
val field = keySchema(ordinal)
(field, ordinal)
}
}

private def isFixedSize(dataType: DataType): Boolean = dataType match {
Expand All @@ -248,34 +251,56 @@ class RangeKeyScanStateEncoder(
}

// verify that only fixed sized columns are used for ordering
rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
if (!isFixedSize(field.dataType)) {
// NullType is technically fixed size, but not supported for ordering
if (field.dataType == NullType) {
throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, idx.toString)
throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString)
} else {
throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, idx.toString)
throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString)
}
}
}

private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
keySchema.zipWithIndex.drop(numOrderingCols)
private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
val field = keySchema(ordinal)
(field, ordinal)
}
}

private val rangeScanKeyProjection: UnsafeProjection = {
val refs = rangeScanKeyFieldsWithIdx.map(x =>
val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
BoundReference(x._2, x._1.dataType, x._1.nullable))
UnsafeProjection.create(refs)
}

private val remainingKeyProjection: UnsafeProjection = {
val refs = remainingKeyFieldsWithIdx.map(x =>
val refs = remainingKeyFieldsWithOrdinal.map(x =>
BoundReference(x._2, x._1.dataType, x._1.nullable))
UnsafeProjection.create(refs)
}

private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
// The original schema that we might get could be:
// [foo, bar, baz, buzz]
// We might order by bar and buzz, leading to:
// [bar, buzz, foo, baz]
// We need to create a projection that sends, for example, the buzz at index 1 to index
// 3. Thus, for every field in the original schema, we compute where it would be in
// the joined row and create a projection based on that.
private val restoreKeyProjection: UnsafeProjection = {
val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal)) {
orderingOrdinals.indexOf(originalOrdinal)
} else {
orderingOrdinals.length +
remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
}

BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
}
UnsafeProjection.create(refs)
}

// Reusable objects
private val joinedRowOnKey = new JoinedRow()
@@ -307,9 +332,10 @@ class RangeKeyScanStateEncoder(
// the sorting order on iteration.
// Also note that the same byte is used to indicate whether the value is negative or not.
private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
val writer = new UnsafeRowWriter(numOrderingCols)
val writer = new UnsafeRowWriter(orderingOrdinals.length)
writer.resetRowWriter()
rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
val field = fieldWithOrdinal._1
val value = row.get(idx, field.dataType)
// Note that we cannot allocate a smaller buffer here even if the value is null
// because the effective byte array is considered variable size and needs to have
@@ -413,9 +439,11 @@ class RangeKeyScanStateEncoder(
// actual value.
// For negative float/double values, we need to flip all the bits back to get the original value.
private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
val writer = new UnsafeRowWriter(numOrderingCols)
val writer = new UnsafeRowWriter(orderingOrdinals.length)
writer.resetRowWriter()
rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
val field = fieldWithOrdinal._1

val value = row.getBinary(idx)
val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
bbuf.order(ByteOrder.BIG_ENDIAN)
@@ -464,10 +492,11 @@ class RangeKeyScanStateEncoder(
}

override def encodeKey(row: UnsafeRow): Array[Byte] = {
// This prefix key has the columns specified by orderingOrdinals
val prefixKey = extractPrefixKey(row)
val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))

val result = if (numOrderingCols < keySchema.length) {
val result = if (orderingOrdinals.length < keySchema.length) {
val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
val encodedBytes = new Array[Byte](rangeScanKeyEncoded.length + remainingEncoded.length + 4)
Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, rangeScanKeyEncoded.length)
@@ -498,10 +527,10 @@ class RangeKeyScanStateEncoder(
Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)

val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded,
numFields = numOrderingCols)
numFields = orderingOrdinals.length)
val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan)

if (numOrderingCols < keySchema.length) {
if (orderingOrdinals.length < keySchema.length) {
// Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen

Expand All @@ -511,9 +540,11 @@ class RangeKeyScanStateEncoder(
remainingKeyEncodedLen)

val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded,
numFields = keySchema.length - numOrderingCols)
numFields = keySchema.length - orderingOrdinals.length)

restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)
val restored = restoreKeyProjection(joined)
restored
} else {
// if the number of ordering cols is same as the number of key schema cols, we only
// return the prefix key decoded unsafe row.
Expand Up @@ -301,11 +301,12 @@ case class PrefixKeyScanStateEncoderSpec(
}
}

/** Encodes rows so that they can be range-scanned based on orderingOrdinals */
case class RangeKeyScanStateEncoderSpec(
keySchema: StructType,
numOrderingCols: Int) extends KeyStateEncoderSpec {
if (numOrderingCols == 0 || numOrderingCols > keySchema.length) {
throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(numOrderingCols.toString)
orderingOrdinals: Seq[Int]) extends KeyStateEncoderSpec {
if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) {
throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString)
}
}
