diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 45c7c9f8ac689..5fe3b0f82a0a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -22,7 +22,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -315,10 +315,11 @@ trait FlatMapGroupsWithStateExecBase val groupedInitialStateIter = GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output) - if (skipEmittingInitialStateKeys) { - // If we are skipping emitting initial state keys, we can just process the initial state - // rows to populate the state store and then process the child data rows. - groupedInitialStateIter.foreach { case (keyRow, initialStateRowIter) => + // Create a CoGroupedIterator that will group the two iterators together for every + // key group. + new CoGroupedIterator( + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + case (keyRow, valueRowIter, initialStateRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] var foundInitialStateForKey = false initialStateRowIter.foreach { initialStateRow => @@ -329,40 +330,19 @@ trait FlatMapGroupsWithStateExecBase val initStateObj = getStateObj.get(initialStateRow) stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) } - } - groupedChildDataIter.flatMap { case (keyRow, valueRowIter) => - val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] - callFunctionAndUpdateState( - stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false) - } - } else { - // Create a CoGroupedIterator that will group the two iterators together for every - // key group. - new CoGroupedIterator( - groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { - case (keyRow, valueRowIter, initialStateRowIter) => - val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] - var foundInitialStateForKey = false - initialStateRowIter.foreach { initialStateRow => - if (foundInitialStateForKey) { - FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException() - } - foundInitialStateForKey = true - val initStateObj = getStateObj.get(initialStateRow) - stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) - } - // We apply the values for the key after applying the initial state. + if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) { + // If the user has specified to skip emitting the keys that only have initial state + // and no data, then we should not call the function for such keys. + Iterator.empty + } else { callFunctionAndUpdateState( stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false - ) + valueRowIter, + hasTimedOut = false) + } } } - } /** Find the groups that have timeout set and are timing out right now, and call the function */ def processTimedOutState(): Iterator[InternalRow] = { @@ -568,38 +548,36 @@ object FlatMapGroupsWithStateExec { initialState: SparkPlan, child: SparkPlan): SparkPlan = { if (hasInitialState) { - // we wont support skipping emitting initial state keys for batch queries - // since the underlying CoGroupExec does not support it - if (skipEmittingInitialStateKeys) { - throw SparkUnsupportedOperationException() - } - val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => { - // Check if there is only one state for every key. - var foundInitialStateForKey = false - val optionalStates = states.map { stateValue => - if (foundInitialStateForKey) { - foundDuplicateInitialKeyException() - } - foundInitialStateForKey = true - stateValue - }.toArray - - // Create group state object - val groupState = GroupStateImpl.createForStreaming( - optionalStates.headOption, - System.currentTimeMillis, - GroupStateImpl.NO_TIMESTAMP, - timeoutConf, - hasTimedOut = false, - watermarkPresent) - - // Call user function with the state and values for this key - userFunc(keyRow, values, groupState) + if (skipEmittingInitialStateKeys && values.isEmpty) { + Iterator.empty + } else { + // Check if there is only one state for every key. + var foundInitialStateForKey = false + val optionalStates = states.map { stateValue => + if (foundInitialStateForKey) { + foundDuplicateInitialKeyException() + } + foundInitialStateForKey = true + stateValue + }.toArray + + // Create group state object + val groupState = GroupStateImpl.createForStreaming( + optionalStates.headOption, + System.currentTimeMillis, + GroupStateImpl.NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false, + watermarkPresent) + + // Call user function with the state and values for this key + userFunc(keyRow, values, groupState) + } } CoGroupExec( func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala index 9d9e144d800a0..dd4e3615d43ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -450,6 +450,36 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest } } + Seq(true, false).foreach { skipEmittingInitialStateKeys => + testWithAllStateVersions("flatMapGroupsWithState - batch query and " + + s"skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> + skipEmittingInitialStateKeys.toString) { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map(x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = Seq("orange", "mango") + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + val df = result.toDF() + if (skipEmittingInitialStateKeys) { + checkAnswer(df, Seq(("orange", 3), ("mango", 1)).toDF()) + } else { + checkAnswer(df, Seq(("apple", 1), ("orange", 3), ("mango", 1)).toDF()) + } + } + } + } + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { test(s"$name - state format version $version") {