diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index 3dd6bdbe9ae27..2f9de24594b28 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -452,7 +452,7 @@ def close(self):
 class _GrpcDataChannel(DataChannel):
   """Base class for implementing a BeamFnData-based DataChannel."""
 
-  _WRITES_FINISHED = object()
+  _WRITES_FINISHED = beam_fn_api_pb2.Elements.Data()
 
   def __init__(self, data_buffer_time_limit_ms=0):
     # type: (int) -> None
@@ -475,7 +475,7 @@ def __init__(self, data_buffer_time_limit_ms=0):
 
   def close(self):
     # type: () -> None
-    self._to_send.put(self._WRITES_FINISHED)  # type: ignore[arg-type]
+    self._to_send.put(self._WRITES_FINISHED)
     self._closed = True
 
   def wait(self, timeout=None):
@@ -639,8 +639,12 @@ def _write_outputs(self):
       streams = [self._to_send.get()]
       try:
         # Coalesce up to 100 other items.
-        for _ in range(100):
-          streams.append(self._to_send.get_nowait())
+        total_size_bytes = streams[0].ByteSize()
+        while (total_size_bytes < _DEFAULT_SIZE_FLUSH_THRESHOLD and
+               len(streams) <= 100):
+          data_or_timer = self._to_send.get_nowait()
+          total_size_bytes += data_or_timer.ByteSize()
+          streams.append(data_or_timer)
       except queue.Empty:
         pass
       if streams[-1] is self._WRITES_FINISHED: