From 65db30323dd3e514e8a9cb76bff409a8303e461e Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Thu, 23 Jan 2025 16:30:59 -0800 Subject: [PATCH 1/4] [SPARK-50967] Add option to skip emitting initial state keys within the FMGWS operator --- .../apache/spark/sql/internal/SQLConf.scala | 9 +++ .../spark/sql/execution/SparkStrategies.scala | 12 +++- .../FlatMapGroupsInPandasWithStateExec.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 60 +++++++++++++++---- .../FlatMapGroupsWithStateSuite.scala | 1 + ...GroupsWithStateWithInitialStateSuite.scala | 60 +++++++++++++++++++ 6 files changed, 132 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 066006ce7082e..d3b6572caa787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2292,6 +2292,15 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS = + buildConf("spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys") + .internal() + .doc("When true, the flatMapGroupsWithState operation in a streaming query will not emit " + + "results for the initial state keys of each group.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .version("2.0.0") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 36e25773f8342..bf169c4c99ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -736,11 +736,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) => val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val skipEmittingInitialStateKeys = + conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS) val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr, None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, - planLater(initialState), hasInitialState, planLater(child) + planLater(initialState), hasInitialState, skipEmittingInitialStateKeys, planLater(child) ) execPlan :: Nil case _ => @@ -828,7 +830,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val execPlan = python.FlatMapGroupsInPandasWithStateExec( func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, - eventTimeWatermarkForEviction = None, planLater(child) + eventTimeWatermarkForEviction = None, + skipEmittingInitialStateKeys = false, + planLater(child) ) execPlan :: Nil case _ => @@ -953,10 +957,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode, isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs, initialStateDataAttrs, initialStateDeserializer, initialState, child) => + val skipEmittingInitialStateKeys = + conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS) FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries( f, keyDeserializer, 
valueDeserializer, initialStateDeserializer, grouping, initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, - hasInitialState, planLater(initialState), planLater(child) + hasInitialState, skipEmittingInitialStateKeys, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeMode, outputMode, keyEncoder, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index eef0b3e3e8469..76bb164436624 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -50,6 +50,7 @@ import org.apache.spark.util.CompletionIterator * @param batchTimestampMs processing timestamp of the current batch. * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param skipEmittingInitialStateKeys whether to skip emitting initial state df keys * @param child logical plan of the underlying data */ case class FlatMapGroupsInPandasWithStateExec( @@ -64,6 +65,7 @@ case class FlatMapGroupsInPandasWithStateExec( batchTimestampMs: Option[Long], eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], + skipEmittingInitialStateKeys: Boolean, child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. 
@@ -137,7 +139,8 @@ case class FlatMapGroupsInPandasWithStateExec( override def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + initStateIter: Iterator[InternalRow], + skipEmittingInitialStateKeys: Boolean): Iterator[InternalRow] = { throw SparkUnsupportedOperationException() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 58d2a19989cbf..45c7c9f8ac689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -22,7 +22,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -52,6 +52,7 @@ trait FlatMapGroupsWithStateExecBase protected val initialStateDataAttrs: Seq[Attribute] protected val initialState: SparkPlan protected val hasInitialState: Boolean + protected val skipEmittingInitialStateKeys: Boolean val stateInfo: Option[StatefulOperatorStateInfo] protected val stateEncoder: ExpressionEncoder[Any] @@ -145,7 +146,8 @@ trait FlatMapGroupsWithStateExecBase val processedOutputIterator = initialStateIterOption match { case Some(initStateIter) if initStateIter.hasNext => - processor.processNewDataWithInitialState(filteredIter, initStateIter) + processor.processNewDataWithInitialState(filteredIter, initStateIter, + skipEmittingInitialStateKeys) case _ => processor.processNewData(filteredIter) } @@ -301,7 +303,8 @@ trait 
FlatMapGroupsWithStateExecBase */ def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow] + initStateIter: Iterator[InternalRow], + skipEmittingInitialStateKeys: Boolean ): Iterator[InternalRow] = { if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty @@ -312,10 +315,10 @@ trait FlatMapGroupsWithStateExecBase val groupedInitialStateIter = GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output) - // Create a CoGroupedIterator that will group the two iterators together for every key group. - new CoGroupedIterator( - groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { - case (keyRow, valueRowIter, initialStateRowIter) => + if (skipEmittingInitialStateKeys) { + // If we are skipping emitting initial state keys, we can just process the initial state + // rows to populate the state store and then process the child data rows. + groupedInitialStateIter.foreach { case (keyRow, initialStateRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] var foundInitialStateForKey = false initialStateRowIter.foreach { initialStateRow => @@ -326,14 +329,40 @@ trait FlatMapGroupsWithStateExecBase val initStateObj = getStateObj.get(initialStateRow) stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) } - // We apply the values for the key after applying the initial state. + } + + groupedChildDataIter.flatMap { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false - ) + valueRowIter, + hasTimedOut = false) + } + } else { + // Create a CoGroupedIterator that will group the two iterators together for every + // key group. 
+ new CoGroupedIterator( + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + case (keyRow, valueRowIter, initialStateRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + var foundInitialStateForKey = false + initialStateRowIter.foreach { initialStateRow => + if (foundInitialStateForKey) { + FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException() + } + foundInitialStateForKey = true + val initStateObj = getStateObj.get(initialStateRow) + stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) + } + // We apply the values for the key after applying the initial state. + callFunctionAndUpdateState( + stateManager.getState(store, keyUnsafeRow), + valueRowIter, + hasTimedOut = false + ) } } + } /** Find the groups that have timeout set and are timing out right now, and call the function */ def processTimedOutState(): Iterator[InternalRow] = { @@ -388,6 +417,7 @@ trait FlatMapGroupsWithStateExecBase * @param eventTimeWatermarkForEviction event time watermark for state eviction * @param initialState the user specified initial state * @param hasInitialState indicates whether the initial state is provided or not + * @param skipEmittingInitialStateKeys whether to skip emitting initial state df keys * @param child the physical plan for the underlying data */ case class FlatMapGroupsWithStateExec( @@ -410,6 +440,7 @@ case class FlatMapGroupsWithStateExec( eventTimeWatermarkForEviction: Option[Long], initialState: SparkPlan, hasInitialState: Boolean, + skipEmittingInitialStateKeys: Boolean, child: SparkPlan) extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec { import GroupStateImpl._ @@ -533,9 +564,16 @@ object FlatMapGroupsWithStateExec { outputObjAttr: Attribute, timeoutConf: GroupStateTimeout, hasInitialState: Boolean, + skipEmittingInitialStateKeys: Boolean, initialState: SparkPlan, child: SparkPlan): SparkPlan = { if (hasInitialState) { + // we wont support skipping emitting 
initial state keys for batch queries + // since the underlying CoGroupExec does not support it + if (skipEmittingInitialStateKeys) { + throw SparkUnsupportedOperationException() + } + val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index f7ff39622ed40..f1feb62b7622a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1177,6 +1177,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark), RDDScanExec(g, emptyRdd, "rdd"), hasInitialState, + false, RDDScanExec(g, emptyRdd, "rdd")) }.get } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala index 2a2a83d35e1f8..e9470c7ac46ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -351,6 +351,66 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest ) } + testWithAllStateVersions("flatMapGroupsWithState - initial state - " + + s"skipEmittingInitialStateKeys=true") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "true") { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L), + ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: 
GroupState[Long]) => { + val count = state.getOption.map( x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + testStream(result, Update)( + AddData(inputData, "apple"), + AddData(inputData, "banana"), + CheckNewAnswer(("apple", 2), ("banana", 1)), + AddData(inputData, "orange"), + CheckNewAnswer(("orange", 3)), + StopStream + ) + } + } + + testWithAllStateVersions("flatMapGroupsWithState - initial state - " + + s"skipEmittingInitialStateKeys=false") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "false") { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L), + ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map( x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + testStream(result, Update)( + AddData(inputData, "apple"), + AddData(inputData, "banana"), + CheckNewAnswer(("apple", 2), ("banana", 1), ("orange", 2), ("mango", 5)), + AddData(inputData, "orange"), + CheckNewAnswer(("orange", 3)), + StopStream + ) + } + } + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { test(s"$name - state format version $version") { From 78ae35ea0178e949fe31293358e0888e23bf200b Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Thu, 23 Jan 2025 16:54:44 -0800 Subject: [PATCH 2/4] Add test --- ...GroupsWithStateWithInitialStateSuite.scala | 40 +++++++++++++++++++ 1 file changed, 40 
insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala index e9470c7ac46ac..653f9873a72b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -351,6 +351,8 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest ) } + // if the keys part of initial state df are different than the keys in the input data, then + // they will not be emitted as part of the result with skipEmittingInitialStateKeys set to true testWithAllStateVersions("flatMapGroupsWithState - initial state - " + s"skipEmittingInitialStateKeys=true") { withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "true") { @@ -381,6 +383,8 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest } } + // if the keys part of initial state df are different than the keys in the input data, then + // they will be emitted as part of the result with skipEmittingInitialStateKeys set to false testWithAllStateVersions("flatMapGroupsWithState - initial state - " + s"skipEmittingInitialStateKeys=false") { withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "false") { @@ -411,6 +415,42 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest } } + // if the keys part of the initial state and the first batch are the same, then the result + // is the same irrespective of the value of skipEmittingInitialStateKeys + Seq(true, false).foreach { skipEmittingInitialStateKeys => + testWithAllStateVersions("flatMapGroupsWithState - initial state and initial batch " + + s"have same keys and
skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> + skipEmittingInitialStateKeys.toString) { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map(x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + testStream(result, Update)( + AddData(inputData, "apple"), + AddData(inputData, "apple"), + AddData(inputData, "orange"), + CheckNewAnswer(("apple", 3), ("orange", 3)), + AddData(inputData, "orange"), + CheckNewAnswer(("orange", 4)), + StopStream + ) + } + } + } + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { test(s"$name - state format version $version") { From 275b0282e259e5f480d6159274364b00d98553ba Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Thu, 23 Jan 2025 16:55:47 -0800 Subject: [PATCH 3/4] Fix space --- .../streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala index 653f9873a72b9..9d9e144d800a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -450,7 +450,6 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends 
StateStoreMetricsTest } } - def testWithAllStateVersions(name: String)(func: => Unit): Unit = { for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { test(s"$name - state format version $version") { From b62a4b077839267f52822a7bd599a68acd8864ae Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Fri, 31 Jan 2025 17:42:44 -0800 Subject: [PATCH 4/4] Address Jungtaek's comments --- .../FlatMapGroupsWithStateExec.scala | 100 +++++++----------- ...GroupsWithStateWithInitialStateSuite.scala | 30 ++++++ 2 files changed, 69 insertions(+), 61 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 45c7c9f8ac689..5fe3b0f82a0a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -22,7 +22,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -315,10 +315,11 @@ trait FlatMapGroupsWithStateExecBase val groupedInitialStateIter = GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output) - if (skipEmittingInitialStateKeys) { - // If we are skipping emitting initial state keys, we can just process the initial state - // rows to populate the state store and then process the child data rows. - groupedInitialStateIter.foreach { case (keyRow, initialStateRowIter) => + // Create a CoGroupedIterator that will group the two iterators together for every + // key group.
+ new CoGroupedIterator( + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + case (keyRow, valueRowIter, initialStateRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] var foundInitialStateForKey = false initialStateRowIter.foreach { initialStateRow => @@ -329,40 +330,19 @@ trait FlatMapGroupsWithStateExecBase val initStateObj = getStateObj.get(initialStateRow) stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) } - } - groupedChildDataIter.flatMap { case (keyRow, valueRowIter) => - val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] - callFunctionAndUpdateState( - stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false) - } - } else { - // Create a CoGroupedIterator that will group the two iterators together for every - // key group. - new CoGroupedIterator( - groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { - case (keyRow, valueRowIter, initialStateRowIter) => - val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] - var foundInitialStateForKey = false - initialStateRowIter.foreach { initialStateRow => - if (foundInitialStateForKey) { - FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException() - } - foundInitialStateForKey = true - val initStateObj = getStateObj.get(initialStateRow) - stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) - } - // We apply the values for the key after applying the initial state. + if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) { + // If the user has specified to skip emitting the keys that only have initial state + // and no data, then we should not call the function for such keys. 
+ Iterator.empty + } else { callFunctionAndUpdateState( stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false - ) + valueRowIter, + hasTimedOut = false) + } } } - } /** Find the groups that have timeout set and are timing out right now, and call the function */ def processTimedOutState(): Iterator[InternalRow] = { @@ -568,38 +548,36 @@ object FlatMapGroupsWithStateExec { initialState: SparkPlan, child: SparkPlan): SparkPlan = { if (hasInitialState) { - // we wont support skipping emitting initial state keys for batch queries - // since the underlying CoGroupExec does not support it - if (skipEmittingInitialStateKeys) { - throw SparkUnsupportedOperationException() - } - val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => { - // Check if there is only one state for every key. - var foundInitialStateForKey = false - val optionalStates = states.map { stateValue => - if (foundInitialStateForKey) { - foundDuplicateInitialKeyException() - } - foundInitialStateForKey = true - stateValue - }.toArray - - // Create group state object - val groupState = GroupStateImpl.createForStreaming( - optionalStates.headOption, - System.currentTimeMillis, - GroupStateImpl.NO_TIMESTAMP, - timeoutConf, - hasTimedOut = false, - watermarkPresent) - - // Call user function with the state and values for this key - userFunc(keyRow, values, groupState) + if (skipEmittingInitialStateKeys && values.isEmpty) { + Iterator.empty + } else { + // Check if there is only one state for every key. 
+ var foundInitialStateForKey = false + val optionalStates = states.map { stateValue => + if (foundInitialStateForKey) { + foundDuplicateInitialKeyException() + } + foundInitialStateForKey = true + stateValue + }.toArray + + // Create group state object + val groupState = GroupStateImpl.createForStreaming( + optionalStates.headOption, + System.currentTimeMillis, + GroupStateImpl.NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false, + watermarkPresent) + + // Call user function with the state and values for this key + userFunc(keyRow, values, groupState) + } } CoGroupExec( func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala index 9d9e144d800a0..dd4e3615d43ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -450,6 +450,36 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest } } + Seq(true, false).foreach { skipEmittingInitialStateKeys => + testWithAllStateVersions("flatMapGroupsWithState - batch query and " + + s"skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> + skipEmittingInitialStateKeys.toString) { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map(x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = Seq("orange", "mango") + val result = + inputData.toDS() + 
.groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + val df = result.toDF() + if (skipEmittingInitialStateKeys) { + checkAnswer(df, Seq(("orange", 3), ("mango", 1)).toDF()) + } else { + checkAnswer(df, Seq(("apple", 1), ("orange", 3), ("mango", 1)).toDF()) + } + } + } + } + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { test(s"$name - state format version $version") {