diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 3b6ff4090ced..846007ccaee6 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -110,6 +110,16 @@ def execute( logger.debug("Execution config: %s", self._options) + # Note: DAG must be initialized in order to query num_outputs_total. + # Note: Initialize global progress bar before building the streaming + # topology so bars are created in the same order as they should be + # displayed. This is done to ensure correct ordering within notebooks. + # TODO(zhilong): Implement num_output_rows_total for all + # AllToAllOperators + self._global_info = ProgressBar( + "Running", dag.num_output_rows_total(), unit="row" + ) + # Setup the streaming DAG topology and start the runner thread. self._topology, _ = build_streaming_topology(dag, self._options) self._resource_manager = ResourceManager( @@ -126,14 +136,6 @@ def execute( self._has_op_completed = {op: False for op in self._topology} - if not isinstance(dag, InputDataBuffer): - # Note: DAG must be initialized in order to query num_outputs_total. - # TODO(zhilong): Implement num_output_rows_total for all - # AllToAllOperators - self._global_info = ProgressBar( - "Running", dag.num_output_rows_total(), unit="row" - ) - self._output_node: OpState = self._topology[dag] StatsManager.register_dataset_to_stats_actor( self._dataset_tag, diff --git a/python/ray/experimental/tqdm_ray.py b/python/ray/experimental/tqdm_ray.py index 436cdcb5c8e3..e1cd885a8b3c 100644 --- a/python/ray/experimental/tqdm_ray.py +++ b/python/ray/experimental/tqdm_ray.py @@ -242,8 +242,12 @@ def update_bar(self, state: ProgressBarState) -> None: def close_bar(self, state: ProgressBarState) -> None: """Remove a bar from this group.""" bar = self.bars_by_uuid[state["uuid"]] + # Note: Hide and then unhide bars to prevent flashing of the + # last bar when we are closing multiple bars sequentially. + instance().hide_bars() bar.close() del self.bars_by_uuid[state["uuid"]] + instance().unhide_bars() def slots_required(self): """Return the number of pos slots we need to accomodate bars in this group."""