From 60806c63d97bc35f62a049b2185eb921217904c4 Mon Sep 17 00:00:00 2001
From: Neil Ramaswamy
Date: Mon, 8 Apr 2024 17:39:18 +0900
Subject: [PATCH] [SPARK-47746] Implement ordinal-based range encoding in the
 RocksDBStateEncoder

### What changes were proposed in this pull request?

The RocksDBStateEncoder now implements range projection by reading a list of
ordering ordinals and using it to project the corresponding columns, in
big-endian byte order, to the front of the `Array[Byte]` encoded rows returned
by the encoder.

### Why are the changes needed?

StateV2 implementations (and other state-related operators) project certain
columns to the front of `UnsafeRow`s and then rely on the RocksDBStateEncoder
to range-encode those columns. We can avoid that initial projection by passing
the RocksDBStateEncoder the ordinals to encode at the front, which should
avoid any GC or codegen overhead associated with the extra projection.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UTs. All existing UTs should pass.

### Was this patch authored or co-authored using generative AI tooling?

Yes

Closes #45905 from neilramaswamy/spark-47746.

Authored-by: Neil Ramaswamy
Signed-off-by: Jungtaek Lim
---
 .../main/resources/error/error-classes.json   |   2 +-
 docs/sql-error-conditions.md                  |   2 +-
 .../sql/execution/streaming/TTLState.scala    |   2 +-
 .../execution/streaming/TimerStateImpl.scala  |   2 +-
 .../streaming/state/RocksDBStateEncoder.scala |  87 ++++++---
 .../streaming/state/StateStore.scala          |   7 +-
 .../state/RocksDBStateStoreSuite.scala        | 184 +++++++++++++++---
 .../streaming/state/StateStoreSuite.scala     |   2 +-
 8 files changed, 228 insertions(+), 60 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index f28adaf402303..c3a01e9dcd907 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3630,7 +3630,7 @@
   },
   "STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN" : {
     "message" : [
-      "Incorrect number of ordering columns=<numOrderingCols> for range scan encoder. Ordering columns cannot be zero or greater than num of schema columns."
+      "Incorrect number of ordering ordinals=<numOrderingCols> for range scan encoder. The number of ordering ordinals cannot be zero or greater than the number of schema columns."
     ],
     "sqlState" : "42802"
   },
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index d8261b8c2765e..1887af2e814be 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2236,7 +2236,7 @@ Please only use the StatefulProcessor within the transformWithState operator.
 
 [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
 
-Incorrect number of ordering columns=`<numOrderingCols>` for range scan encoder. Ordering columns cannot be zero or greater than num of schema columns.
+Incorrect number of ordering ordinals=`<numOrderingCols>` for range scan encoder. The number of ordering ordinals cannot be zero or greater than the number of schema columns.
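
The encoding relies on a simple property: for non-negative fixed-size values, unsigned byte-wise comparison of their big-endian encodings agrees with numeric order, so RocksDB's lexicographic iterator returns range-scanned keys already sorted. A minimal standalone Scala sketch of that property (`bigEndian` is a hypothetical helper for illustration, not code from this patch):

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Hypothetical helper: the big-endian bytes of a Long.
def bigEndian(v: Long): Array[Byte] =
  ByteBuffer.allocate(java.lang.Long.BYTES).order(ByteOrder.BIG_ENDIAN).putLong(v).array()

// Unsigned byte-wise comparison of big-endian encodings matches numeric order
// for non-negative values, which is what lets RocksDB's lexicographic iterator
// yield range-scan results in sorted order.
assert(java.util.Arrays.compareUnsigned(bigEndian(90L), bigEndian(8000L)) < 0)
assert(java.util.Arrays.compareUnsigned(bigEndian(931L), bigEndian(452300L)) < 0)
// Negative values additionally need the null/sign marker byte handling that
// the encoder comments below describe.
```
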
 
 ### STATE_STORE_INCORRECT_NUM_PREFIX_COLS_FOR_PREFIX_SCAN
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
index 0ae93549b731a..f64c8cc44555f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -93,7 +93,7 @@ abstract class SingleKeyTTLStateImpl(
     UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
 
   store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA,
-    RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, 1), isInternal = true)
+    RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, Seq(0)), isInternal = true)
 
   def upsertTTLForStateKey(
       expirationMs: Long,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 8d410b677c84b..55acc4953c506 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -91,7 +91,7 @@ class TimerStateImpl(
   private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
 
   store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
-    schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
+    schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
     useMultipleValuesPerKey = false, isInternal = true)
 
   private def getGroupingKey(cfName: String): Any = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index e9b910a76148f..80c228d15334d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -52,8 +52,8 @@ object RocksDBStateEncoder {
     case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
       new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)
 
-    case RangeKeyScanStateEncoderSpec(keySchema, numOrderingCols) =>
-      new RangeKeyScanStateEncoder(keySchema, numOrderingCols)
+    case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+      new RangeKeyScanStateEncoder(keySchema, orderingOrdinals)
 
     case _ =>
       throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
@@ -204,8 +204,8 @@ class PrefixKeyScanStateEncoder(
 /**
  * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size fields
  *
- * To encode a row for range scan, we first project the first numOrderingCols needed
- * for the range scan into an UnsafeRow; we then rewrite that UnsafeRow's fields in BIG_ENDIAN
+ * To encode a row for range scan, we first project the orderingOrdinals from the original
+ * UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's fields in BIG_ENDIAN
  * to allow for scanning keys in sorted order using the byte-wise comparison method that
  * RocksDB uses.
  *
@@ -213,9 +213,9 @@ class PrefixKeyScanStateEncoder(
  * We then effectively join these two UnsafeRows together, and finally take those bytes
  * to get the resulting row.
  *
- * We cannot support variable sized fields given the UnsafeRow format which stores variable
- * sized fields as offset and length pointers to the actual values, thereby changing the required
- * ordering.
+ * We cannot support variable sized fields in the range scan because the UnsafeRow format
+ * stores variable sized fields as offset and length pointers to the actual values,
+ * thereby changing the required ordering.
  *
  * Note that we also support "null" values being passed for these fixed size fields. We prepend
  * a single byte to indicate whether the column value is null or not. We cannot change the
@@ -229,16 +229,19 @@ class RangeKeyScanStateEncoder(
  * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
  *
  * @param keySchema - schema of the key to be encoded
- * @param numOrderingCols - number of columns to be used for range scan
+ * @param orderingOrdinals - the ordinals for which the range scan is constructed
  */
 class RangeKeyScanStateEncoder(
     keySchema: StructType,
-    numOrderingCols: Int) extends RocksDBKeyStateEncoder {
+    orderingOrdinals: Seq[Int]) extends RocksDBKeyStateEncoder {
 
   import RocksDBStateEncoder._
 
-  private val rangeScanKeyFieldsWithIdx: Seq[(StructField, Int)] = {
-    keySchema.zipWithIndex.take(numOrderingCols)
+  private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+    orderingOrdinals.map { ordinal =>
+      val field = keySchema(ordinal)
+      (field, ordinal)
+    }
   }
 
   private def isFixedSize(dataType: DataType): Boolean = dataType match {
@@ -248,34 +251,56 @@ class RangeKeyScanStateEncoder(
   }
 
   // verify that only fixed sized columns are used for ordering
-  rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+  rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
     if (!isFixedSize(field.dataType)) {
       // NullType is technically fixed size, but not supported for ordering
       if (field.dataType == NullType) {
-        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, idx.toString)
+        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString)
       } else {
-        throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, idx.toString)
+        throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString)
       }
     }
   }
 
-  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
-    keySchema.zipWithIndex.drop(numOrderingCols)
+  private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+    0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
+      val field = keySchema(ordinal)
+      (field, ordinal)
+    }
   }
 
   private val rangeScanKeyProjection: UnsafeProjection = {
-    val refs = rangeScanKeyFieldsWithIdx.map(x =>
+    val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
       BoundReference(x._2, x._1.dataType, x._1.nullable))
     UnsafeProjection.create(refs)
   }
 
   private val remainingKeyProjection: UnsafeProjection = {
-    val refs = remainingKeyFieldsWithIdx.map(x =>
+    val refs = remainingKeyFieldsWithOrdinal.map(x =>
       BoundReference(x._2, x._1.dataType, x._1.nullable))
     UnsafeProjection.create(refs)
   }
 
-  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+  // The original schema that we might get could be:
+  //    [foo, bar, baz, buzz]
+  // We might order by bar and buzz, leading to:
+  //    [bar, buzz, foo, baz]
+  // We need to create a projection that sends, for example, the buzz at index 1 to index
+  // 3. Thus, for every field in the original schema, we compute where it would be in
+  // the joined row and create a projection based on that.
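
To make the mapping described in the comment above concrete, here is a small standalone sketch (illustrative values only; the field names and `orderingOrdinals` mirror the `[foo, bar, baz, buzz]` example from the comment, not real schema code) that computes each original field's position in the joined row:

```scala
// Schema [foo, bar, baz, buzz] ordered by (bar, buzz) gives the joined
// layout [bar, buzz, foo, baz]; compute each field's joined-row position.
val orderingOrdinals = Seq(1, 3)
val numFields = 4
val remainingOrdinals = (0 until numFields).filterNot(orderingOrdinals.contains)

val ordinalInJoinedRow = (0 until numFields).map { originalOrdinal =>
  if (orderingOrdinals.contains(originalOrdinal)) {
    orderingOrdinals.indexOf(originalOrdinal)
  } else {
    orderingOrdinals.length + remainingOrdinals.indexOf(originalOrdinal)
  }
}
// foo -> 2, bar -> 0, baz -> 3, buzz -> 1
assert(ordinalInJoinedRow == Seq(2, 0, 3, 1))
```
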
+  private val restoreKeyProjection: UnsafeProjection = {
+    val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
+      val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal)) {
+          orderingOrdinals.indexOf(originalOrdinal)
+      } else {
+          orderingOrdinals.length +
+            remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
+      }
+
+      BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
+    }
+    UnsafeProjection.create(refs)
+  }
 
   // Reusable objects
   private val joinedRowOnKey = new JoinedRow()
@@ -307,9 +332,10 @@ class RangeKeyScanStateEncoder(
   // the sorting order on iteration.
   // Also note that the same byte is used to indicate whether the value is negative or not.
   private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
-    val writer = new UnsafeRowWriter(numOrderingCols)
+    val writer = new UnsafeRowWriter(orderingOrdinals.length)
     writer.resetRowWriter()
-    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
       val value = row.get(idx, field.dataType)
       // Note that we cannot allocate a smaller buffer here even if the value is null
       // because the effective byte array is considered variable size and needs to have
@@ -413,9 +439,11 @@ class RangeKeyScanStateEncoder(
   // actual value.
   // For negative float/double values, we need to flip all the bits back to get the original value.
   private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
-    val writer = new UnsafeRowWriter(numOrderingCols)
+    val writer = new UnsafeRowWriter(orderingOrdinals.length)
     writer.resetRowWriter()
-    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+
       val value = row.getBinary(idx)
       val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
       bbuf.order(ByteOrder.BIG_ENDIAN)
@@ -464,10 +492,11 @@ class RangeKeyScanStateEncoder(
   }
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    // This prefix key has the columns specified by orderingOrdinals
     val prefixKey = extractPrefixKey(row)
 
     val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
 
-    val result = if (numOrderingCols < keySchema.length) {
+    val result = if (orderingOrdinals.length < keySchema.length) {
       val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
       val encodedBytes = new Array[Byte](rangeScanKeyEncoded.length + remainingEncoded.length + 4)
       Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, rangeScanKeyEncoded.length)
@@ -498,10 +527,10 @@ class RangeKeyScanStateEncoder(
       Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
 
     val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded,
-      numFields = numOrderingCols)
+      numFields = orderingOrdinals.length)
     val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan)
 
-    if (numOrderingCols < keySchema.length) {
+    if (orderingOrdinals.length < keySchema.length) {
       // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
       val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen
 
       val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
       Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen,
@@ -511,9 +540,11 @@ class RangeKeyScanStateEncoder(
         remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
         remainingKeyEncodedLen)
 
       val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded,
-        numFields = keySchema.length - numOrderingCols)
+        numFields = keySchema.length - orderingOrdinals.length)
 
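
The framing that `encodeKey` writes and `decodeKey` reads back is a 4-byte length prefix followed by the two encoded sections. A simplified sketch of that layout, using `ByteBuffer` as a stand-in for the `Platform.putInt`/`copyMemory` calls (which use native byte order) in the real code:

```scala
import java.nio.ByteBuffer

// Illustrative framing only: [ 4-byte prefix length | range-scan bytes | remaining bytes ]
def frame(rangeScanPart: Array[Byte], remainingPart: Array[Byte]): Array[Byte] = {
  val buf = ByteBuffer.allocate(4 + rangeScanPart.length + remainingPart.length)
  buf.putInt(rangeScanPart.length) // length prefix makes the split recoverable
  buf.put(rangeScanPart)
  buf.put(remainingPart)
  buf.array()
}

// Decoding reads the prefix length and splits the two sections back apart,
// so each can be decoded into its own UnsafeRow before re-joining.
def unframe(bytes: Array[Byte]): (Array[Byte], Array[Byte]) = {
  val buf = ByteBuffer.wrap(bytes)
  val prefixLen = buf.getInt()
  val prefix = new Array[Byte](prefixLen)
  buf.get(prefix)
  val remaining = new Array[Byte](bytes.length - 4 - prefixLen)
  buf.get(remaining)
  (prefix, remaining)
}
```
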
-      restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
+      val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)
+      val restored = restoreKeyProjection(joined)
+      restored
     } else {
       // if the number of ordering cols is same as the number of key schema cols, we only
       // return the prefix key decoded unsafe row.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d3b3264b8e3dd..959cbbaef8b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -301,11 +301,12 @@ case class PrefixKeyScanStateEncoderSpec(
   }
 }
 
+/** Encodes rows so that they can be range-scanned based on orderingOrdinals */
 case class RangeKeyScanStateEncoderSpec(
     keySchema: StructType,
-    numOrderingCols: Int) extends KeyStateEncoderSpec {
-  if (numOrderingCols == 0 || numOrderingCols > keySchema.length) {
-    throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(numOrderingCols.toString)
+    orderingOrdinals: Seq[Int]) extends KeyStateEncoderSpec {
+  if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) {
+    throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 16a5935e04f4b..f3eb8a392d040 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.util.UUID
 
+import scala.collection.immutable
 import scala.util.Random
 
 import org.apache.hadoop.conf.Configuration
@@ -166,7 +167,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     // zero ordering cols
     val ex1 = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 0),
+        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq()),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -180,10 +181,12 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       matchPVals = true
     )
 
-    // ordering cols greater than schema cols
+    // ordering ordinals greater than schema cols
    val ex2 = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, keySchemaWithRangeScan.length + 1),
+        RangeKeyScanStateEncoderSpec(
+          keySchemaWithRangeScan,
+          0.to(keySchemaWithRangeScan.length)),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -205,7 +208,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     val ex = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithVariableSizeCols,
-        RangeKeyScanStateEncoderSpec(keySchemaWithVariableSizeCols, 1),
+        RangeKeyScanStateEncoderSpec(keySchemaWithVariableSizeCols, Seq(0)),
         colFamiliesEnabled)) { provider =>
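
Usage of the new ordinal-based spec looks like the following sketch (hypothetical schema, for illustration only; only the count validation shown above runs at construction time, while the fixed-size and null-type checks happen later in the encoder):

```scala
import org.apache.spark.sql.types._

// Hypothetical two-column key schema.
val keySchema = StructType(Seq(
  StructField("key1", StringType, nullable = false),
  StructField("ordering1", LongType, nullable = false)))

// Range scan driven by the second column: name it by ordinal, no projection needed.
val spec = RangeKeyScanStateEncoderSpec(keySchema, Seq(1))

// Both of these fail the count validation above with
// STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN:
//   RangeKeyScanStateEncoderSpec(keySchema, Seq())         // zero ordering ordinals
//   RangeKeyScanStateEncoderSpec(keySchema, Seq(0, 1, 0))  // more ordinals than columns
```
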
         provider.getStore(0)
       }
@@ -221,6 +224,46 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     )
   }
 
+  testWithColumnFamilies("rocksdb range scan validation - variable size data types unsupported",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
+    val keySchemaWithSomeUnsupportedTypeCols: StructType = StructType(Seq(
+      StructField("key1", StringType, false),
+      StructField("key2", IntegerType, false),
+      StructField("key3", FloatType, false),
+      StructField("key4", BinaryType, false)
+    ))
+    val allowedRangeOrdinals = Seq(1, 2)
+
+    keySchemaWithSomeUnsupportedTypeCols.fields.zipWithIndex.foreach { case (field, index) =>
+      val isAllowed = allowedRangeOrdinals.contains(index)
+
+      val getStore = () => {
+        tryWithProviderResource(newStoreProvider(keySchemaWithSomeUnsupportedTypeCols,
+          RangeKeyScanStateEncoderSpec(keySchemaWithSomeUnsupportedTypeCols, Seq(index)),
+          colFamiliesEnabled)) { provider =>
+          provider.getStore(0)
+        }
+      }
+
+      if (isAllowed) {
+        getStore()
+      } else {
+        val ex = intercept[SparkUnsupportedOperationException] {
+          getStore()
+        }
+        checkError(
+          ex,
+          errorClass = "STATE_STORE_VARIABLE_SIZE_ORDERING_COLS_NOT_SUPPORTED",
+          parameters = Map(
+            "fieldName" -> field.name,
+            "index" -> index.toString
+          ),
+          matchPVals = true
+        )
+      }
+    }
+  }
+
   testWithColumnFamilies("rocksdb range scan validation - null type columns",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
     val keySchemaWithNullTypeCols: StructType = StructType(
@@ -228,7 +271,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     val ex = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithNullTypeCols,
-        RangeKeyScanStateEncoderSpec(keySchemaWithNullTypeCols, 1),
+        RangeKeyScanStateEncoderSpec(keySchemaWithNullTypeCols, Seq(0)),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -248,7 +291,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
 
     tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+      colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       // use non-default col family if column families are enabled
       val cfName = "testColFamily"
@@ -256,7 +300,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           keySchemaWithRangeScan, valueSchema,
-          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1))
+          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)))
       }
 
       val timerTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L,
@@ -305,14 +349,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](DoubleType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 1))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0)))
       }
 
       // Verify that the sort ordering here is as follows:
@@ -355,14 +399,15 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
 
     tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+      colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           keySchemaWithRangeScan, valueSchema,
-          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1))
+          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)))
       }
 
       val timerTimestamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L,
@@ -415,14 +460,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](LongType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps = Seq((931L, 10), (8000L, 40), (452300L, 1), (4200L, 68), (90L, 2000),
@@ -447,6 +492,96 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamilies("rocksdb range scan multiple non-contiguous ordering columns",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
+    val testSchema: StructType = StructType(
+      Seq(
+        StructField("ordering-1", LongType, false),
+        StructField("key2", StringType, false),
+        StructField("ordering-2", IntegerType, false),
+        StructField("string-2", StringType, false),
+        StructField("ordering-3", DoubleType, false)
+      )
+    )
+
+    val testSchemaProj = UnsafeProjection.create(Array[DataType](
+      immutable.ArraySeq.unsafeWrapArray(testSchema.fields.map(_.dataType)): _*))
+    val rangeScanOrdinals = Seq(0, 2, 4)
+
+    tryWithProviderResource(
+      newStoreProvider(
+        testSchema,
+        RangeKeyScanStateEncoderSpec(testSchema, rangeScanOrdinals),
+        colFamiliesEnabled
+      )
+    ) { provider =>
+      val store = provider.getStore(0)
+
+      val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
+      if (colFamiliesEnabled) {
+        store.createColFamilyIfAbsent(
+          cfName,
+          testSchema,
+          valueSchema,
+          RangeKeyScanStateEncoderSpec(testSchema, rangeScanOrdinals)
+        )
+      }
+
+      val orderedInput = Seq(
+        // Make sure that the first column takes precedence, even if the
+        // later columns are greater
+        (-2L, 0, 99.0),
+        (-1L, 0, 98.0),
+        (0L, 0, 97.0),
+        (2L, 0, 96.0),
+        // Make sure that the second column takes precedence, when the first
+        // column is all the same
+        (3L, -2, -1.0),
+        (3L, -1, -2.0),
+        (3L, 0, -3.0),
+        (3L, 2, -4.0),
+        // Finally, make sure that the third column takes precedence, when the
+        // first two ordering columns are the same.
+        (4L, -1, -127.0),
+        (4L, -1, 0.0),
+        (4L, -1, 64.0),
+        (4L, -1, 127.0)
+      )
+      val scrambledInput = Random.shuffle(orderedInput)
+
+      scrambledInput.foreach { record =>
+        val keyRow = testSchemaProj.apply(
+          new GenericInternalRow(
+            Array[Any](
+              record._1,
+              UTF8String.fromString(Random.alphanumeric.take(Random.nextInt(20) + 1).mkString),
+              record._2,
+              UTF8String.fromString(Random.alphanumeric.take(Random.nextInt(20) + 1).mkString),
+              record._3
+            )
+          )
+        )
+
+        // The value is just a "dummy" value of 1
+        val valueRow = dataToValueRow(1)
+        store.put(keyRow, valueRow, cfName)
+        assert(valueRowToData(store.get(keyRow, cfName)) === 1)
+      }
+
+      val result = store
+        .iterator(cfName)
+        .map { kv =>
+          val keyRow = kv.key
+          val key = (keyRow.getLong(0), keyRow.getInt(2), keyRow.getDouble(4))
+          (key._1, key._2, key._3)
+        }
+        .toSeq
+
+      assert(result === orderedInput)
+    }
+  }
+
   testWithColumnFamilies("rocksdb range scan multiple ordering columns - variable size " +
     s"non-ordering columns with null values in first ordering column",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
@@ -459,14 +594,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](LongType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps = Seq((931L, 10), (null, 40), (452300L, 1),
@@ -522,7 +657,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
       if (colFamiliesEnabled) {
         store1.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps1 = Seq((null, 3), (null, 1), (null, 32),
@@ -559,14 +694,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](LongType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps = Seq((931L, 10), (40L, null), (452300L, 1),
@@ -612,14 +747,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](ByteType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
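
The expected iteration order in the non-contiguous test above is plain lexicographic order on the `(ordinal 0, ordinal 2, ordinal 4)` tuple. A store-independent sanity check of that invariant, assuming nothing beyond the test data itself:

```scala
// The ordered input from the test above; RocksDB iteration order should equal
// lexicographic order on the (Long, Int, Double) ordering tuple.
val orderedInput = Seq(
  (-2L, 0, 99.0), (-1L, 0, 98.0), (0L, 0, 97.0), (2L, 0, 96.0),
  (3L, -2, -1.0), (3L, -1, -2.0), (3L, 0, -3.0), (3L, 2, -4.0),
  (4L, -1, -127.0), (4L, -1, 0.0), (4L, -1, 64.0), (4L, -1, 127.0))

// Dependency-free check that the expected output really is tuple-lexicographically sorted.
val isSorted = orderedInput.sliding(2).forall {
  case Seq((a1, a2, a3), (b1, b2, b3)) =>
    a1 < b1 || (a1 == b1 && (a2 < b2 || (a2 == b2 && a3 <= b3)))
}
assert(isSorted)
```
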
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps: Seq[(Byte, Int)] = Seq((0x33, 10), (0x1A, 40), (0x1F, 1), (0x01, 68),
@@ -649,13 +784,13 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     // use the same schema as value schema for single col key schema
     tryWithProviderResource(newStoreProvider(valueSchema,
-      RangeKeyScanStateEncoderSpec(valueSchema, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           valueSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(valueSchema, 1))
+          RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)))
       }
 
       val timerTimestamps = Seq(931, 8000, 452300, 4200,
@@ -690,14 +825,15 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
 
     tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+      colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
 
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           keySchemaWithRangeScan, valueSchema,
-          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1))
+          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)))
       }
 
       val timerTimestamps = Seq(931L, -1331L, 8000L, 1L, -244L, -8350L, -55L)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 231396aff2226..4523a14ca1ccd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -200,7 +200,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   test("running with range scan encoder should fail") {
     val ex = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-        keyStateEncoderSpec = RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1),
+        keyStateEncoderSpec = RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
         useColumnFamilies = false)) { provider =>
         provider.getStore(0)
       }
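
The negative timestamps in the tests above work because of the marker byte the encoder comments describe. A hypothetical sketch of that scheme (`encodeSigned` is an illustrative stand-in, not the encoder's actual code): prepend `0x00` for negative and `0x01` for non-negative values, so every negative key sorts before every non-negative one under unsigned byte comparison, while big-endian two's-complement bytes order the negatives among themselves:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Hypothetical helper: marker byte + big-endian two's-complement payload.
def encodeSigned(v: Long): Array[Byte] = {
  val buf = ByteBuffer.allocate(1 + java.lang.Long.BYTES).order(ByteOrder.BIG_ENDIAN)
  buf.put(if (v < 0) 0x00.toByte else 0x01.toByte) // negatives sort before non-negatives
  buf.putLong(v)
  buf.array()
}

// Matches the iteration order the negative-timestamp test above expects.
val sorted = Seq(931L, -1331L, 8000L, 1L, -244L, -8350L, -55L)
  .sortWith((a, b) => java.util.Arrays.compareUnsigned(encodeSigned(a), encodeSigned(b)) < 0)
assert(sorted == Seq(-8350L, -1331L, -244L, -55L, 1L, 931L, 8000L))
```
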