From 1d7b6fb214be279aaabf07cdf13ce269366055d7 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 8 Jun 2023 11:41:27 -0700
Subject: [PATCH] Fix stats() reporting for streaming_split iterators

Signed-off-by: Eric Liang
---
 .../iterator/stream_split_iterator.py        | 10 +++++++-
 python/ray/data/tests/test_stats.py          | 24 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py
index ef2adb65df1d..c6c27a7a4694 100644
--- a/python/ray/data/_internal/iterator/stream_split_iterator.py
+++ b/python/ray/data/_internal/iterator/stream_split_iterator.py
@@ -96,7 +96,7 @@ def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
 
     def stats(self) -> str:
         """Implements DataIterator."""
-        return self._base_dataset.stats()
+        return ray.get(self._coord_actor.stats.remote())
 
     def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
         """Implements DataIterator."""
@@ -132,6 +132,7 @@ def __init__(
         self._equal = equal
         self._locality_hints = locality_hints
         self._lock = threading.RLock()
+        self._executor = None
 
         # Guarded by self._lock.
         self._next_bundle: Dict[int, RefBundle] = {}
@@ -143,6 +144,7 @@ def gen_epochs():
                 executor = StreamingExecutor(
                     copy.deepcopy(dataset.context.execution_options)
                 )
+                self._executor = executor
 
                 def add_split_op(dag):
                     return OutputSplitter(dag, n, equal, locality_hints)
@@ -159,6 +161,12 @@ def add_split_op(dag):
         self._next_epoch = gen_epochs()
         self._output_iterator = None
 
+    def stats(self) -> str:
+        """Returns stats from the executor if it exists, else the base dataset."""
+        if self._executor:
+            return self._executor.get_stats().to_summary().to_string()
+        return self._base_dataset.stats()
+
     def start_epoch(self, split_idx: int) -> str:
         """Called to start an epoch.
 
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py
index b7459f175c1c..53379962a84b 100644
--- a/python/ray/data/tests/test_stats.py
+++ b/python/ray/data/tests/test_stats.py
@@ -34,6 +34,30 @@ def dummy_map_batches(x):
     return x
 
 
+def test_streaming_split_stats(ray_start_regular_shared):
+    context = DataContext.get_current()
+    ds = ray.data.range(1000, parallelism=10)
+    it = ds.map_batches(dummy_map_batches).streaming_split(1)[0]
+    list(it.iter_batches())
+    stats = it.stats()
+    assert (
+        canonicalize(stats)
+        == """Stage N ReadRange->MapBatches(dummy_map_batches): N/N blocks executed in T
+* Remote wall time: T min, T max, T mean, T total
+* Remote cpu time: T min, T max, T mean, T total
+* Peak heap memory usage (MiB): N min, N max, N mean
+* Output num rows: N min, N max, N mean, N total
+* Output size bytes: N min, N max, N mean, N total
+* Tasks per node: N min, N max, N mean; N nodes used
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
+'obj_store_mem_peak': N}
+
+Stage N split(N, equal=False):
+* Extra metrics: {'num_output_N': N}
+"""
+    )
+
+
 def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
     context = DataContext.get_current()
     context.optimize_fuse_stages = True
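
For reference, a minimal usage sketch of the code path this change affects (not part of the patch; the dataset size and variable names are illustrative, and it assumes a running Ray installation). It mirrors the new test: after streaming_split(), DataIterator.stats() is now answered by the SplitCoordinator actor's executor rather than the un-executed base dataset.

    import ray

    ds = ray.data.range(1000, parallelism=10)

    # streaming_split() returns one DataIterator per consumer; take the single one.
    (it,) = ds.streaming_split(1)

    # Consuming the iterator drives the StreamingExecutor inside the coordinator
    # actor, which is where the per-stage stats now accumulate.
    for _ in it.iter_batches():
        pass

    # With this patch, stats() is fetched via ray.get(coord_actor.stats.remote())
    # instead of from the base dataset, so the executed stages show up here.
    print(it.stats())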