From 22eb6c4b0a82b9fcf84fc9952b1f6c41dde9bd8d Mon Sep 17 00:00:00 2001
From: Wei Liu
Date: Wed, 24 Jul 2024 18:57:45 +0900
Subject: [PATCH] [SPARK-48567][SS][FOLLOWUP] StreamingQuery.lastProgress
 should return the actual StreamingQueryProgress

This reverts commit d067fc6c1635dfe7730223021e912e78637bb791, which itself reverted
042804ad545c88afe69c149b25baea00fc213708, essentially bringing the original change back.
042804ad545c88afe69c149b25baea00fc213708 failed the 3.5 client <> 4.0 server test, but
that test was turned off for cross-version testing in
https://github.com/apache/spark/pull/47468.

### What changes were proposed in this pull request?

This PR was created after discussion in the closed PR
https://github.com/apache/spark/pull/46886. I was trying to fix a bug (in Connect,
`query.lastProgress` doesn't have `numInputRows`, `inputRowsPerSecond`, and
`processedRowsPerSecond`), and we reached the conclusion that what is proposed in this
PR should be the ultimate fix.

In Python, for both classic Spark and Spark Connect, the return type of `lastProgress`
is `Dict` (and `recentProgress` is `List[Dict]`), but in Scala it is the actual
`StreamingQueryProgress` object:
https://github.com/apache/spark/blob/1a5d22aa2ffe769435be4aa6102ef961c55b9593/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala#L94-L101

This API discrepancy causes confusion: in Scala, users can write
`query.lastProgress.batchId`, while in Python they have to write
`query.lastProgress["batchId"]`.

This PR makes `StreamingQuery.lastProgress` return the actual `StreamingQueryProgress`
(and `StreamingQuery.recentProgress` return `List[StreamingQueryProgress]`). To avoid a
breaking change, we make `StreamingQueryProgress` a subclass of `dict`, so existing code
that accesses fields with dictionary methods (e.g. `query.lastProgress["id"]`) keeps
working.

### Why are the changes needed?

API parity.

### Does this PR introduce _any_ user-facing change?

Yes, `StreamingQuery.lastProgress` now returns the actual `StreamingQueryProgress` (and
`StreamingQuery.recentProgress` returns `List[StreamingQueryProgress]`).

### How was this patch tested?

Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47470 from WweiL/bring-back-lastProgress.
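As an illustration of the behavior after this change (not part of the patch), here is a
minimal sketch; the rate source, `noop` sink, and session setup below are assumptions
made for the example, not taken from this PR:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("lastProgress-example").getOrCreate()

# Any streaming source works here; "rate" is used only so some data flows.
df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
query = df.writeStream.format("noop").start()
query.processAllAvailable()

# lastProgress is now a StreamingQueryProgress (a dict subclass), not a plain dict.
progress = query.lastProgress
if progress is not None:
    # New: attribute access, matching the Scala API.
    print(progress.batchId, progress.numInputRows)
    # Existing dictionary access keeps working, so older code is unaffected.
    print(progress["batchId"], progress["numInputRows"])
    # For backward compatibility, "id"/"runId" stay strings under dict access,
    # while attribute access returns the richer uuid.UUID object.
    assert progress["id"] == str(progress.id)

query.stop()
spark.stop()
```

Because `StreamingQueryProgress` stores its fields in the underlying `dict`, equality
checks and JSON-shaped lookups in existing user code remain intact while the typed
properties are added on top.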
Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/streaming/query.py | 9 +- python/pyspark/sql/streaming/listener.py | 228 +++++++++++------- python/pyspark/sql/streaming/query.py | 13 +- .../sql/tests/streaming/test_streaming.py | 44 +++- .../streaming/test_streaming_listener.py | 32 ++- 5 files changed, 227 insertions(+), 99 deletions(-) diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 40c975ea4032c..204d16106482d 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -33,6 +33,7 @@ QueryProgressEvent, QueryIdleEvent, QueryTerminatedEvent, + StreamingQueryProgress, ) from pyspark.sql.streaming.query import ( StreamingQuery as PySparkStreamingQuery, @@ -110,21 +111,21 @@ def status(self) -> Dict[str, Any]: status.__doc__ = PySparkStreamingQuery.status.__doc__ @property - def recentProgress(self) -> List[Dict[str, Any]]: + def recentProgress(self) -> List[StreamingQueryProgress]: cmd = pb2.StreamingQueryCommand() cmd.recent_progress = True progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json - return [json.loads(p) for p in progress] + return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress] recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__ @property - def lastProgress(self) -> Optional[Dict[str, Any]]: + def lastProgress(self) -> Optional[StreamingQueryProgress]: cmd = pb2.StreamingQueryCommand() cmd.last_progress = True progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json if len(progress) > 0: - return json.loads(progress[-1]) + return StreamingQueryProgress.fromJson(json.loads(progress[-1])) else: return None diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 2aa63cdb91ab6..6cc2cc3fa2b86 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -397,10 +397,13 @@ def errorClassOnException(self) -> Optional[str]: return self._errorClassOnException -class StreamingQueryProgress: +class StreamingQueryProgress(dict): """ .. versionadded:: 3.4.0 + .. versionchanged:: 4.0.0 + Becomes a subclass of dict + Notes ----- This API is evolving. 
@@ -426,23 +429,25 @@ def __init__( jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ): + super().__init__( + id=id, + runId=runId, + name=name, + timestamp=timestamp, + batchId=batchId, + batchDuration=batchDuration, + durationMs=durationMs, + eventTime=eventTime, + stateOperators=stateOperators, + sources=sources, + sink=sink, + numInputRows=numInputRows, + inputRowsPerSecond=inputRowsPerSecond, + processedRowsPerSecond=processedRowsPerSecond, + observedMetrics=observedMetrics, + ) self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict - self._id: uuid.UUID = id - self._runId: uuid.UUID = runId - self._name: Optional[str] = name - self._timestamp: str = timestamp - self._batchId: int = batchId - self._batchDuration: int = batchDuration - self._durationMs: Dict[str, int] = durationMs - self._eventTime: Dict[str, str] = eventTime - self._stateOperators: List[StateOperatorProgress] = stateOperators - self._sources: List[SourceProgress] = sources - self._sink: SinkProgress = sink - self._numInputRows: int = numInputRows - self._inputRowsPerSecond: float = inputRowsPerSecond - self._processedRowsPerSecond: float = processedRowsPerSecond - self._observedMetrics: Dict[str, Row] = observedMetrics @classmethod def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress": @@ -489,9 +494,11 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], sources=[SourceProgress.fromJson(s) for s in j["sources"]], sink=SinkProgress.fromJson(j["sink"]), - numInputRows=j["numInputRows"], - inputRowsPerSecond=j["inputRowsPerSecond"], - processedRowsPerSecond=j["processedRowsPerSecond"], + numInputRows=j["numInputRows"] if "numInputRows" in j else None, + inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None, + processedRowsPerSecond=j["processedRowsPerSecond"] + if "processedRowsPerSecond" in j + else None, observedMetrics={ k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows for k, row_dict in j["observedMetrics"].items() @@ -506,7 +513,10 @@ def id(self) -> uuid.UUID: A unique query id that persists across restarts. See py:meth:`~pyspark.sql.streaming.StreamingQuery.id`. """ - return self._id + # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId + # to string. But here they are UUID. + # To prevent breaking change, do not cast them to string when accessed with attribute. + return super().__getitem__("id") @property def runId(self) -> uuid.UUID: @@ -514,21 +524,24 @@ def runId(self) -> uuid.UUID: A query id that is unique for every start/restart. See py:meth:`~pyspark.sql.streaming.StreamingQuery.runId`. """ - return self._runId + # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId + # to string. But here they are UUID. + # To prevent breaking change, do not cast them to string when accessed with attribute. + return super().__getitem__("runId") @property def name(self) -> Optional[str]: """ User-specified name of the query, `None` if not specified. """ - return self._name + return self["name"] @property def timestamp(self) -> str: """ The timestamp to start a query. """ - return self._timestamp + return self["timestamp"] @property def batchId(self) -> int: @@ -538,21 +551,21 @@ def batchId(self) -> int: Similarly, when there is no data to be processed, the batchId will not be incremented. 
""" - return self._batchId + return self["batchId"] @property def batchDuration(self) -> int: """ The process duration of each batch. """ - return self._batchDuration + return self["batchDuration"] @property def durationMs(self) -> Dict[str, int]: """ The amount of time taken to perform various operations in milliseconds. """ - return self._durationMs + return self["durationMs"] @property def eventTime(self) -> Dict[str, str]: @@ -570,21 +583,21 @@ def eventTime(self) -> Dict[str, str]: All timestamps are in ISO8601 format, i.e. UTC timestamps. """ - return self._eventTime + return self["eventTime"] @property def stateOperators(self) -> List["StateOperatorProgress"]: """ Information about operators in the query that store state. """ - return self._stateOperators + return self["stateOperators"] @property def sources(self) -> List["SourceProgress"]: """ detailed statistics on data being read from each of the streaming sources. """ - return self._sources + return self["sources"] @property def sink(self) -> "SinkProgress": @@ -592,32 +605,41 @@ def sink(self) -> "SinkProgress": A unique query id that persists across restarts. See py:meth:`~pyspark.sql.streaming.StreamingQuery.id`. """ - return self._sink + return self["sink"] @property def observedMetrics(self) -> Dict[str, Row]: - return self._observedMetrics + return self["observedMetrics"] @property def numInputRows(self) -> int: """ The aggregate (across all sources) number of records processed in a trigger. """ - return self._numInputRows + if self["numInputRows"] is not None: + return self["numInputRows"] + else: + return sum(s.numInputRows for s in self.sources) @property def inputRowsPerSecond(self) -> float: """ The aggregate (across all sources) rate of data arriving. """ - return self._inputRowsPerSecond + if self["inputRowsPerSecond"] is not None: + return self["inputRowsPerSecond"] + else: + return sum(s.inputRowsPerSecond for s in self.sources) @property def processedRowsPerSecond(self) -> float: """ The aggregate (across all sources) rate at which Spark is processing data. """ - return self._processedRowsPerSecond + if self["processedRowsPerSecond"] is not None: + return self["processedRowsPerSecond"] + else: + return sum(s.processedRowsPerSecond for s in self.sources) @property def json(self) -> str: @@ -641,14 +663,29 @@ def prettyJson(self) -> str: else: return json.dumps(self._jdict, indent=4) + def __getitem__(self, key: str) -> Any: + # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId + # to string. But here they are UUID. + # To prevent breaking change, also cast them to string when accessed with __getitem__. + if key == "id" or key == "runId": + return str(super().__getitem__(key)) + else: + return super().__getitem__(key) + def __str__(self) -> str: return self.prettyJson + def __repr__(self) -> str: + return self.prettyJson + -class StateOperatorProgress: +class StateOperatorProgress(dict): """ .. versionadded:: 3.4.0 + .. versionchanged:: 4.0.0 + Becomes a subclass of dict + Notes ----- This API is evolving. 
@@ -671,20 +708,22 @@ def __init__( jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ): + super().__init__( + operatorName=operatorName, + numRowsTotal=numRowsTotal, + numRowsUpdated=numRowsUpdated, + numRowsRemoved=numRowsRemoved, + allUpdatesTimeMs=allUpdatesTimeMs, + allRemovalsTimeMs=allRemovalsTimeMs, + commitTimeMs=commitTimeMs, + memoryUsedBytes=memoryUsedBytes, + numRowsDroppedByWatermark=numRowsDroppedByWatermark, + numShufflePartitions=numShufflePartitions, + numStateStoreInstances=numStateStoreInstances, + customMetrics=customMetrics, + ) self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict - self._operatorName: str = operatorName - self._numRowsTotal: int = numRowsTotal - self._numRowsUpdated: int = numRowsUpdated - self._numRowsRemoved: int = numRowsRemoved - self._allUpdatesTimeMs: int = allUpdatesTimeMs - self._allRemovalsTimeMs: int = allRemovalsTimeMs - self._commitTimeMs: int = commitTimeMs - self._memoryUsedBytes: int = memoryUsedBytes - self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark - self._numShufflePartitions: int = numShufflePartitions - self._numStateStoreInstances: int = numStateStoreInstances - self._customMetrics: Dict[str, int] = customMetrics @classmethod def fromJObject(cls, jprogress: "JavaObject") -> "StateOperatorProgress": @@ -724,51 +763,51 @@ def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": @property def operatorName(self) -> str: - return self._operatorName + return self["operatorName"] @property def numRowsTotal(self) -> int: - return self._numRowsTotal + return self["numRowsTotal"] @property def numRowsUpdated(self) -> int: - return self._numRowsUpdated + return self["numRowsUpdated"] @property def allUpdatesTimeMs(self) -> int: - return self._allUpdatesTimeMs + return self["allUpdatesTimeMs"] @property def numRowsRemoved(self) -> int: - return self._numRowsRemoved + return self["numRowsRemoved"] @property def allRemovalsTimeMs(self) -> int: - return self._allRemovalsTimeMs + return self["allRemovalsTimeMs"] @property def commitTimeMs(self) -> int: - return self._commitTimeMs + return self["commitTimeMs"] @property def memoryUsedBytes(self) -> int: - return self._memoryUsedBytes + return self["memoryUsedBytes"] @property def numRowsDroppedByWatermark(self) -> int: - return self._numRowsDroppedByWatermark + return self["numRowsDroppedByWatermark"] @property def numShufflePartitions(self) -> int: - return self._numShufflePartitions + return self["numShufflePartitions"] @property def numStateStoreInstances(self) -> int: - return self._numStateStoreInstances + return self["numStateStoreInstances"] @property - def customMetrics(self) -> Dict[str, int]: - return self._customMetrics + def customMetrics(self) -> dict: + return self["customMetrics"] @property def json(self) -> str: @@ -795,11 +834,17 @@ def prettyJson(self) -> str: def __str__(self) -> str: return self.prettyJson + def __repr__(self) -> str: + return self.prettyJson + -class SourceProgress: +class SourceProgress(dict): """ .. versionadded:: 3.4.0 + .. versionchanged:: 4.0.0 + Becomes a subclass of dict + Notes ----- This API is evolving. 
@@ -818,16 +863,18 @@ def __init__( jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ) -> None: + super().__init__( + description=description, + startOffset=startOffset, + endOffset=endOffset, + latestOffset=latestOffset, + numInputRows=numInputRows, + inputRowsPerSecond=inputRowsPerSecond, + processedRowsPerSecond=processedRowsPerSecond, + metrics=metrics, + ) self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict - self._description: str = description - self._startOffset: str = startOffset - self._endOffset: str = endOffset - self._latestOffset: str = latestOffset - self._numInputRows: int = numInputRows - self._inputRowsPerSecond: float = inputRowsPerSecond - self._processedRowsPerSecond: float = processedRowsPerSecond - self._metrics: Dict[str, str] = metrics @classmethod def fromJObject(cls, jprogress: "JavaObject") -> "SourceProgress": @@ -862,53 +909,53 @@ def description(self) -> str: """ Description of the source. """ - return self._description + return self["description"] @property def startOffset(self) -> str: """ The starting offset for data being read. """ - return self._startOffset + return self["startOffset"] @property def endOffset(self) -> str: """ The ending offset for data being read. """ - return self._endOffset + return self["endOffset"] @property def latestOffset(self) -> str: """ The latest offset from this source. """ - return self._latestOffset + return self["latestOffset"] @property def numInputRows(self) -> int: """ The number of records read from this source. """ - return self._numInputRows + return self["numInputRows"] @property def inputRowsPerSecond(self) -> float: """ The rate at which data is arriving from this source. """ - return self._inputRowsPerSecond + return self["inputRowsPerSecond"] @property def processedRowsPerSecond(self) -> float: """ The rate at which data from this source is being processed by Spark. """ - return self._processedRowsPerSecond + return self["processedRowsPerSecond"] @property - def metrics(self) -> Dict[str, str]: - return self._metrics + def metrics(self) -> dict: + return self["metrics"] @property def json(self) -> str: @@ -935,11 +982,17 @@ def prettyJson(self) -> str: def __str__(self) -> str: return self.prettyJson + def __repr__(self) -> str: + return self.prettyJson + -class SinkProgress: +class SinkProgress(dict): """ .. versionadded:: 3.4.0 + .. versionchanged:: 4.0.0 + Becomes a subclass of dict + Notes ----- This API is evolving. @@ -953,11 +1006,13 @@ def __init__( jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ) -> None: + super().__init__( + description=description, + numOutputRows=numOutputRows, + metrics=metrics, + ) self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict - self._description: str = description - self._numOutputRows: int = numOutputRows - self._metrics: Dict[str, str] = metrics @classmethod def fromJObject(cls, jprogress: "JavaObject") -> "SinkProgress": @@ -982,7 +1037,7 @@ def description(self) -> str: """ Description of the source. """ - return self._description + return self["description"] @property def numOutputRows(self) -> int: @@ -990,11 +1045,11 @@ def numOutputRows(self) -> int: Number of rows written to the sink or -1 for Continuous Mode (temporarily) or Sink V1 (until decommissioned). 
""" - return self._numOutputRows + return self["numOutputRows"] @property def metrics(self) -> Dict[str, str]: - return self._metrics + return self["metrics"] @property def json(self) -> str: @@ -1021,6 +1076,9 @@ def prettyJson(self) -> str: def __str__(self) -> str: return self.prettyJson + def __repr__(self) -> str: + return self.prettyJson + def _test() -> None: import sys diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index bbce29cb43917..28274e9fadc2a 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -22,7 +22,10 @@ from pyspark.errors.exceptions.captured import ( StreamingQueryException as CapturedStreamingQueryException, ) -from pyspark.sql.streaming.listener import StreamingQueryListener +from pyspark.sql.streaming.listener import ( + StreamingQueryListener, + StreamingQueryProgress, +) if TYPE_CHECKING: from py4j.java_gateway import JavaObject @@ -251,7 +254,7 @@ def status(self) -> Dict[str, Any]: return json.loads(self._jsq.status().json()) @property - def recentProgress(self) -> List[Dict[str, Any]]: + def recentProgress(self) -> List[StreamingQueryProgress]: """ Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. The number of progress updates retained for each stream is configured by Spark session @@ -280,10 +283,10 @@ def recentProgress(self) -> List[Dict[str, Any]]: >>> sq.stop() """ - return [json.loads(p.json()) for p in self._jsq.recentProgress()] + return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()] @property - def lastProgress(self) -> Optional[Dict[str, Any]]: + def lastProgress(self) -> Optional[StreamingQueryProgress]: """ Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or None if there were no progress updates @@ -311,7 +314,7 @@ def lastProgress(self) -> Optional[Dict[str, Any]]: """ lastProgress = self._jsq.lastProgress() if lastProgress: - return json.loads(lastProgress.json()) + return StreamingQueryProgress.fromJObject(lastProgress) else: return None diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index e284d052d9ae2..00d1fbf538850 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -29,7 +29,7 @@ class StreamingTestsMixin: def test_streaming_query_functions_basic(self): - df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") query = ( df.writeStream.format("memory") .queryName("test_streaming_query_functions_basic") @@ -43,8 +43,8 @@ def test_streaming_query_functions_basic(self): self.assertEqual(query.exception(), None) self.assertFalse(query.awaitTermination(1)) query.processAllAvailable() - recentProgress = query.recentProgress lastProgress = query.lastProgress + recentProgress = query.recentProgress self.assertEqual(lastProgress["name"], query.name) self.assertEqual(lastProgress["id"], query.id) self.assertTrue(any(p == lastProgress for p in recentProgress)) @@ -59,6 +59,46 @@ def test_streaming_query_functions_basic(self): finally: query.stop() + def test_streaming_progress(self): + """ + Should be able to access fields using attributes in lastProgress / recentProgress + e.g. 
q.lastProgress.id + """ + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + query = df.writeStream.format("noop").start() + try: + query.processAllAvailable() + lastProgress = query.lastProgress + recentProgress = query.recentProgress + self.assertEqual(lastProgress["name"], query.name) + # Return str when accessed using dict get. + self.assertEqual(lastProgress["id"], query.id) + # SPARK-48567 Use attribute to access fields in q.lastProgress + self.assertEqual(lastProgress.name, query.name) + # Return uuid when accessed using attribute. + self.assertEqual(str(lastProgress.id), query.id) + self.assertTrue(any(p == lastProgress for p in recentProgress)) + self.assertTrue(lastProgress.numInputRows > 0) + # Also access source / sink progress with attributes + self.assertTrue(len(lastProgress.sources) > 0) + self.assertTrue(lastProgress.sources[0].numInputRows > 0) + self.assertTrue(lastProgress["sources"][0]["numInputRows"] > 0) + self.assertTrue(lastProgress.sink.numOutputRows > 0) + self.assertTrue(lastProgress["sink"]["numOutputRows"] > 0) + # In Python, for historical reasons, changing field value + # in StreamingQueryProgress is allowed. + new_name = "myNewQuery" + lastProgress["name"] = new_name + self.assertEqual(lastProgress.name, new_name) + + except Exception as e: + self.fail( + "Streaming query functions sanity check shouldn't throw any error. " + "Error message: " + str(e) + ) + finally: + query.stop() + def test_streaming_query_name_edge_case(self): # Query name should be None when not specified q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start() diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 71d584a0418de..c3ae62e64cc30 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -227,9 +227,9 @@ def onQueryTerminated(self, event): "my_event", count(lit(1)).alias("rc"), count(col("error")).alias("erc") ) - q = observed_ds.writeStream.format("console").start() + q = observed_ds.writeStream.format("noop").start() - while q.lastProgress is None or q.lastProgress["batchId"] == 0: + while q.lastProgress is None or q.lastProgress.batchId == 0: q.awaitTermination(0.5) time.sleep(5) @@ -241,6 +241,32 @@ def onQueryTerminated(self, event): q.stop() self.spark.streams.removeListener(error_listener) + def test_streaming_progress(self): + try: + # Test a fancier query with stateful operation and observed metrics + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + df_observe = df.observe("my_event", count(lit(1)).alias("rc")) + df_stateful = df_observe.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("update") + .trigger(processingTime="5 seconds") + .start() + ) + + while q.lastProgress is None or q.lastProgress.batchId == 0: + q.awaitTermination(0.5) + + q.stop() + + self.check_streaming_query_progress(q.lastProgress, True) + for p in q.recentProgress: + self.check_streaming_query_progress(p, True) + + finally: + q.stop() + class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase): def test_number_of_public_methods(self): @@ -355,7 +381,7 @@ def verify(test_listener): .start() ) self.assertTrue(q.isActive) - time.sleep(10) + q.awaitTermination(10) q.stop() # Make sure all events are empty
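For reference, a standalone sketch (not part of the patch) of the backward-compatibility
pattern this change relies on: a `dict` subclass that adds attribute-style accessors while
preserving dictionary access. The class and field names below are illustrative only, not
the real PySpark classes:

```python
import uuid
from typing import Any


class ProgressLike(dict):
    """Illustrative dict subclass exposing selected fields as attributes."""

    def __init__(self, id: uuid.UUID, batchId: int, numInputRows: int):
        # Store all fields in the underlying dict so equality checks and
        # ["key"] lookups from pre-existing code keep working.
        super().__init__(id=id, batchId=batchId, numInputRows=numInputRows)

    @property
    def id(self) -> uuid.UUID:
        # Attribute access returns the raw UUID, mirroring the Scala API.
        return super().__getitem__("id")

    @property
    def batchId(self) -> int:
        return self["batchId"]

    @property
    def numInputRows(self) -> int:
        return self["numInputRows"]

    def __getitem__(self, key: str) -> Any:
        # Dictionary access keeps returning a string for "id", as it did when
        # lastProgress was a plain dict parsed from JSON.
        if key == "id":
            return str(super().__getitem__(key))
        return super().__getitem__(key)


p = ProgressLike(id=uuid.uuid4(), batchId=3, numInputRows=42)
assert p.batchId == p["batchId"] == 3
assert isinstance(p.id, uuid.UUID) and p["id"] == str(p.id)
```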