[SPARK-49131][SS] TransformWithState should properly set implicit grouping keys even with lazy iterators

### What changes were proposed in this pull request?

These changes ensure that implicit grouping key thread locals are set in two places:

1. When `handleInputRows` is called. This allows the user to get/set keyed state in the body of `handleInputRows` before creating the iterator that they return (see the UT).
2. When methods on the iterator returned by `handleInputRows` are called (a simplified sketch of this wrapping pattern follows this list).
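
Conceptually, the fix wraps the iterator returned by `handleInputRows` so that the thread local is set on first use and cleared once the iterator is exhausted. Below is a simplified, self-contained sketch of that wrapping pattern; `withImplicitKey` is a hypothetical helper, not the actual change, which subclasses Spark's `NextIterator` as shown in the diff further down:

```scala
// Simplified sketch of the wrapping pattern; hypothetical helper, not the
// actual Spark implementation (which uses org.apache.spark.util.NextIterator).
def withImplicitKey[A](
    key: Any,
    iter: Iterator[A])(set: Any => Unit, remove: () => Unit): Iterator[A] =
  new Iterator[A] {
    private var started = false
    private var cleared = false

    // Set the thread local lazily, on the first call into the iterator.
    private def ensureKey(): Unit = {
      if (!started) { started = true; set(key) }
    }

    override def hasNext: Boolean = {
      ensureKey()
      val more = iter.hasNext
      // Clear the thread local exactly once, when the iterator is exhausted.
      if (!more && !cleared) { cleared = true; remove() }
      more
    }

    override def next(): A = { ensureKey(); iter.next() }
  }
```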

### Why are the changes needed?

Previously, if `handleInputRows` returned a lazy iterator, the following would happen (a minimal sketch of the failure mode follows this list):

1. The implicit grouping key was set in `processNewData`.
2. `handleInputRows` ran and returned an iterator, call it `iter`.
3. The implicit grouping key was unset.
4. When the sink finally forced `iter` to evaluate, the code inside the iterator ran but could not find the implicit grouping key.
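
The hazard is generic to any thread local guarding a lazily evaluated iterator. A minimal, standalone sketch of the failure mode (hypothetical names; `KeyTracker` stands in for `ImplicitGroupingKeyTracker`):

```scala
// Hypothetical sketch of the lazy-iterator hazard; not Spark code.
object KeyTracker {
  private val key = new ThreadLocal[Option[String]] {
    override def initialValue(): Option[String] = None
  }
  def set(k: String): Unit = key.set(Some(k))
  def remove(): Unit = key.remove()
  def get: Option[String] = key.get()
}

def process(values: Iterator[Int]): Iterator[Int] = {
  KeyTracker.set("key-1")
  // Iterator.map is lazy: this body only runs when the result is consumed.
  val out = values.map { v =>
    require(KeyTracker.get.isDefined, "implicit key not set!")
    v + 1
  }
  KeyTracker.remove() // step 3 above: runs before `out` is ever consumed
  out
}

// Step 4 above: consuming the iterator later (e.g. in the sink) fails,
// because the thread local was already cleared.
// process(Iterator(1, 2, 3)).toList  // => IllegalArgumentException
```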

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- New UT.
- The new UT was verified to fail on current `master`.
- All existing UTs should pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47641 from neilramaswamy/spark-49131.

Authored-by: Neil Ramaswamy <neil.ramaswamy@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
neilramaswamy authored and HeartSaVioR committed Aug 20, 2024
1 parent a6a62e5 commit 5d2d6a3
Showing 2 changed files with 126 additions and 5 deletions.
TransformWithStateExec.scala

@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.types.{BinaryType, StructType}
-import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{CompletionIterator, NextIterator, SerializableConfiguration, Utils}
 
 /**
  * Physical operator for executing `TransformWithState`
@@ -190,6 +190,44 @@ case class TransformWithStateExec(
     groupingAttributes.map(SortOrder(_, Ascending)),
     initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
 
+  // Wrapper to ensure that the implicit key is set when the methods on the iterator
+  // are called. We process all the values for a particular key at a time, so we
+  // only have to set the implicit key when the first call to the iterator is made, and
+  // we have to remove it when the iterator is closed.
+  //
+  // Note: if we ever start to interleave the processing of the iterators we get back
+  // from handleInputRows (i.e. we don't process each iterator all at once), then this
+  // iterator will need to set/unset the implicit key every time hasNext/next is called,
+  // not just at the first and last calls to hasNext.
+  private def iteratorWithImplicitKeySet(
+      key: Any,
+      iter: Iterator[InternalRow],
+      onClose: () => Unit = () => {}
+  ): Iterator[InternalRow] = {
+    new NextIterator[InternalRow] {
+      var hasStarted = false
+
+      override protected def getNext(): InternalRow = {
+        if (!hasStarted) {
+          hasStarted = true
+          ImplicitGroupingKeyTracker.setImplicitKey(key)
+        }
+
+        if (!iter.hasNext) {
+          finished = true
+          null
+        } else {
+          iter.next()
+        }
+      }
+
+      override protected def close(): Unit = {
+        onClose()
+        ImplicitGroupingKeyTracker.removeImplicitKey()
+      }
+    }
+  }
+
   private def handleInputRows(keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow]):
     Iterator[InternalRow] = {
     val getKeyObj =
@@ -201,8 +239,14 @@ case class TransformWithStateExec(
     val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
 
     val keyObj = getKeyObj(keyRow) // convert key to objects
-    ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
     val valueObjIter = valueRowIter.map(getValueObj.apply)
+
+    // The statefulProcessor's handleInputRows method may create an eager iterator,
+    // and in that case, the implicit key needs to be set now. However, it could return
+    // a lazy iterator, in which case the implicit key should be set when the actual
+    // methods on the iterator are invoked. This is done with the wrapper class
+    // at the end of this method.
+    ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
     val mappedIterator = statefulProcessor.handleInputRows(
       keyObj,
       valueObjIter,
@@ -211,7 +255,8 @@ case class TransformWithStateExec(
       getOutputRow(obj)
     }
     ImplicitGroupingKeyTracker.removeImplicitKey()
-    mappedIterator
+
+    iteratorWithImplicitKeySet(keyObj, mappedIterator)
   }
 
   private def processInitialStateRows(
Expand Down Expand Up @@ -263,9 +308,11 @@ case class TransformWithStateExec(
new ExpiredTimerInfoImpl(isValid = true, Some(expiryTimestampMs))).map { obj =>
getOutputRow(obj)
}
processorHandle.deleteTimer(expiryTimestampMs)
ImplicitGroupingKeyTracker.removeImplicitKey()
mappedIterator

iteratorWithImplicitKeySet(keyObj, mappedIterator, () => {
processorHandle.deleteTimer(expiryTimestampMs)
})
}

private def processTimers(
Expand Down
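
The wrapper above relies on Spark's `NextIterator` utility: subclasses implement `getNext()`, setting `finished = true` at exhaustion, and `close()`, which runs exactly once. For readers unfamiliar with it, here is a simplified, self-contained sketch of that contract; this is an assumed simplification, not the real `org.apache.spark.util.NextIterator`:

```scala
// Simplified sketch of the NextIterator contract assumed by
// iteratorWithImplicitKeySet; not the real org.apache.spark.util.NextIterator.
abstract class MiniNextIterator[U] extends Iterator[U] {
  protected var finished = false
  private var gotNext = false
  private var closed = false
  private var nextValue: U = _

  // Produce the next element, or set `finished = true` if there is none.
  protected def getNext(): U
  // Cleanup hook; the wrapper uses it to drop the implicit key thread local.
  protected def close(): Unit

  private def closeIfNeeded(): Unit = {
    if (!closed) { closed = true; close() }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("end of stream")
    gotNext = false
    nextValue
  }
}
```
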
TransformWithStateSuite.scala

@@ -432,6 +432,80 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("transformWithState - lazy iterators can properly get/set keyed state") {
+    class ProcessorWithLazyIterators
+      extends StatefulProcessor[Long, Long, Long] {
+      @transient protected var _myValueState: ValueState[Long] = _
+      var hasSetTimer = false
+
+      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+        _myValueState = getHandle.getValueState[Long](
+          "myValueState",
+          Encoders.scalaLong
+        )
+      }
+
+      override def handleInputRows(
+          key: Long,
+          inputRows: Iterator[Long],
+          timerValues: TimerValues,
+          expiredTimerInfo: ExpiredTimerInfo): Iterator[Long] = {
+        // Eagerly get/set a state variable
+        _myValueState.get()
+        _myValueState.update(1)
+
+        // Create a timer (but only once) so that we can test timers have their implicit key set
+        if (!hasSetTimer) {
+          getHandle.registerTimer(0)
+          hasSetTimer = true
+        }
+
+        // In both of these cases, we return a lazy iterator that gets/sets state variables.
+        // This is to test that the stateful processor can handle lazy iterators.
+        //
+        // The timer uses a Seq(42L) since when the timer fires, inputRows is empty.
+        if (expiredTimerInfo.isValid()) {
+          Seq(42L).iterator.map { r =>
+            _myValueState.get()
+            _myValueState.update(r)
+            r
+          }
+        } else {
+          inputRows.map { r =>
+            _myValueState.get()
+            _myValueState.update(r)
+            r
+          }
+        }
+      }
+    }
+
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString
+    ) {
+      val inputData = MemoryStream[Int]
+      val result = inputData
+        .toDS()
+        .select(timestamp_seconds($"value").as("timestamp"))
+        .withWatermark("timestamp", "10 seconds")
+        .as[Long]
+        .groupByKey(x => x)
+        .transformWithState(
+          new ProcessorWithLazyIterators(), TimeMode.EventTime(), OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(),
+        // Use 12 so that the watermark advances to 2 seconds and causes the timer to fire
+        AddData(inputData, 12),
+        // The 12 is from the input data; the 42 is from the timer
+        CheckAnswer(12, 42)
+      )
+    }
+  }
+
   test("transformWithState - streaming with rocksdb should succeed") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
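
For the arithmetic behind the test's expectation: with a 10-second watermark delay, `AddData(inputData, 12)` advances the event-time watermark to 12 − 10 = 2 seconds, which is past the timer registered at timestamp 0, so the timer becomes eligible to fire and emits 42 through the lazy timer iterator, in addition to the echoed input row 12. Both outputs require the implicit grouping key to be set at the time the lazy iterators are actually evaluated, which is exactly what this patch guarantees.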
