Skip to content

Commit

Permalink
Address Jungtaek's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anishshri-db committed Feb 1, 2025
1 parent 275b028 commit b62a4b0
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
Expand Down Expand Up @@ -315,10 +315,11 @@ trait FlatMapGroupsWithStateExecBase
val groupedInitialStateIter =
GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output)

if (skipEmittingInitialStateKeys) {
// If we are skipping emitting initial state keys, we can just process the initial state
// rows to populate the state store and then process the child data rows.
groupedInitialStateIter.foreach { case (keyRow, initialStateRowIter) =>
// Create a CoGroupedIterator that will group the two iterators together for every
// key group.
new CoGroupedIterator(
groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
case (keyRow, valueRowIter, initialStateRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
var foundInitialStateForKey = false
initialStateRowIter.foreach { initialStateRow =>
Expand All @@ -329,40 +330,19 @@ trait FlatMapGroupsWithStateExecBase
val initStateObj = getStateObj.get(initialStateRow)
stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
}
}

groupedChildDataIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false)
}
} else {
// Create a CoGroupedIterator that will group the two iterators together for every
// key group.
new CoGroupedIterator(
groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
case (keyRow, valueRowIter, initialStateRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
var foundInitialStateForKey = false
initialStateRowIter.foreach { initialStateRow =>
if (foundInitialStateForKey) {
FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
val initStateObj = getStateObj.get(initialStateRow)
stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
}
// We apply the values for the key after applying the initial state.
if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) {
// If the user has specified to skip emitting the keys that only have initial state
// and no data, then we should not call the function for such keys.
Iterator.empty
} else {
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false
)
valueRowIter,
hasTimedOut = false)
}
}
}
}

/** Find the groups that have timeout set and are timing out right now, and call the function */
def processTimedOutState(): Iterator[InternalRow] = {
Expand Down Expand Up @@ -568,38 +548,36 @@ object FlatMapGroupsWithStateExec {
initialState: SparkPlan,
child: SparkPlan): SparkPlan = {
if (hasInitialState) {
// we won't support skipping emitting initial state keys for batch queries
// since the underlying CoGroupExec does not support it
if (skipEmittingInitialStateKeys) {
throw SparkUnsupportedOperationException()
}

val watermarkPresent = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}
val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
// Check if there is only one state for every key.
var foundInitialStateForKey = false
val optionalStates = states.map { stateValue =>
if (foundInitialStateForKey) {
foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
stateValue
}.toArray

// Create group state object
val groupState = GroupStateImpl.createForStreaming(
optionalStates.headOption,
System.currentTimeMillis,
GroupStateImpl.NO_TIMESTAMP,
timeoutConf,
hasTimedOut = false,
watermarkPresent)

// Call user function with the state and values for this key
userFunc(keyRow, values, groupState)
if (skipEmittingInitialStateKeys && values.isEmpty) {
Iterator.empty
} else {
// Check if there is only one state for every key.
var foundInitialStateForKey = false
val optionalStates = states.map { stateValue =>
if (foundInitialStateForKey) {
foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
stateValue
}.toArray

// Create group state object
val groupState = GroupStateImpl.createForStreaming(
optionalStates.headOption,
System.currentTimeMillis,
GroupStateImpl.NO_TIMESTAMP,
timeoutConf,
hasTimedOut = false,
watermarkPresent)

// Call user function with the state and values for this key
userFunc(keyRow, values, groupState)
}
}
CoGroupExec(
func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,36 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest
}
}

// Registers the batch-query flatMapGroupsWithState test twice: once with the
// FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS conf set to true and once with it
// set to false. The initial state holds "apple" and "orange"; the input holds "orange" and
// "mango", so "apple" is the only key that has initial state but no input rows.
for (skipEmittingInitialStateKeys <- Seq(true, false)) {
  testWithAllStateVersions(
    "flatMapGroupsWithState - batch query and " +
      s"skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") {
    val skipConf =
      SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key ->
        skipEmittingInitialStateKeys.toString
    withSQLConf(skipConf) {
      // Initial state keyed by fruit name, carrying a starting count per key.
      val seedCounts = Seq(("apple", 1L), ("orange", 2L))
      val initialState = seedCounts.toDS().groupByKey(_._1).mapValues(_._2)

      // Adds the number of input rows for the key to any existing count, stores the new
      // count in the group state, and emits a single (key, count) pair.
      val countFruit = (fruit: String, rows: Iterator[String], state: GroupState[Long]) => {
        val updated = state.getOption.getOrElse(0L) + rows.size
        state.update(updated)
        Iterator.single((fruit, updated))
      }

      val df = Seq("orange", "mango")
        .toDS()
        .groupByKey(fruit => fruit)
        .flatMapGroupsWithState(Update, NoTimeout(), initialState)(countFruit)
        .toDF()

      // "apple" is expected in the output only when initial-state-only keys are not skipped.
      val expectedRows =
        if (skipEmittingInitialStateKeys) {
          Seq(("orange", 3), ("mango", 1))
        } else {
          Seq(("apple", 1), ("orange", 3), ("mango", 1))
        }
      checkAnswer(df, expectedRows.toDF())
    }
  }
}

def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
test(s"$name - state format version $version") {
Expand Down

0 comments on commit b62a4b0

Please sign in to comment.