From 6f47dd53ceeb5f914c9829afde87875380f39e13 Mon Sep 17 00:00:00 2001 From: amogkam Date: Tue, 21 Mar 2023 18:50:17 -0700 Subject: [PATCH 01/75] update Signed-off-by: amogkam --- python/ray/data/BUILD | 7 ++ .../data/_internal/block_batching/__init__.py | 6 ++ .../{ => block_batching}/block_batching.py | 50 +---------- .../ray/data/_internal/block_batching/util.py | 85 ++++++++++++++++++ .../test_block_batching.py | 61 +------------ .../data/tests/block_batching/test_util.py | 90 +++++++++++++++++++ 6 files changed, 190 insertions(+), 109 deletions(-) create mode 100644 python/ray/data/_internal/block_batching/__init__.py rename python/ray/data/_internal/{ => block_batching}/block_batching.py (89%) create mode 100644 python/ray/data/_internal/block_batching/util.py rename python/ray/data/tests/{ => block_batching}/test_block_batching.py (79%) create mode 100644 python/ray/data/tests/block_batching/test_util.py diff --git a/python/ray/data/BUILD b/python/ray/data/BUILD index 0217c480ab404..f6b2238a5c9de 100644 --- a/python/ray/data/BUILD +++ b/python/ray/data/BUILD @@ -11,6 +11,13 @@ py_library( deps = ["//python/ray/tests:conftest"], ) +py_test_module_list( + files = glob(["tests/block_batching/test_*.py"]), + size = "medium", + tags = ["team:ml", "exclusive"], + deps = ["//:ray_lib", ":conftest"], +) + py_test_module_list( files = glob(["tests/preprocessors/test_*.py"]), size = "small", diff --git a/python/ray/data/_internal/block_batching/__init__.py b/python/ray/data/_internal/block_batching/__init__.py new file mode 100644 index 0000000000000..e7528278877ae --- /dev/null +++ b/python/ray/data/_internal/block_batching/__init__.py @@ -0,0 +1,6 @@ +from ray.data._internal.block_batching.block_batching import ( + batch_blocks, + batch_block_refs, +) + +__all__ = ["batch_blocks", "batch_block_refs"] diff --git a/python/ray/data/_internal/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py similarity index 89% rename from python/ray/data/_internal/block_batching.py rename to python/ray/data/_internal/block_batching/block_batching.py index 168b14137a536..63b3949a02331 100644 --- a/python/ray/data/_internal/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -1,12 +1,11 @@ import collections import itertools -import queue import sys -import threading from typing import Any, Callable, Iterator, Optional, TypeVar, Union import ray from ray.actor import ActorHandle +from ray.data._internal.block_batching.util import _make_async_gen from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.data._internal.memory_tracing import trace_deallocation @@ -175,53 +174,6 @@ def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: yield formatted_batch -def _make_async_gen( - base_iterator: Iterator[T], prefetch_buffer_size: int = 1 -) -> Iterator[T]: - """Returns a new iterator with elements fetched from the base_iterator - in an async fashion using a background thread. - - Args: - base_iterator: The iterator to asynchronously fetch from. - prefetch_buffer_size: The maximum number of items to prefetch. Increasing the - size allows for more computation overlap for very expensive downstream UDFs. - However it comes at the cost of additional memory overhead. Defaults to 1. - - Returns: - An iterator with the same elements as the base_iterator. - """ - - fetch_queue = queue.Queue(maxsize=prefetch_buffer_size) - - sentinel = object() - - def _async_fetch(): - for item in base_iterator: - fetch_queue.put(item, block=True) - - # Indicate done adding items. - fetch_queue.put(sentinel, block=True) - - # Start a background thread which iterates through the base iterator, - # triggering execution and adding results to the queue until it is full. - # Iterating through the iterator returned by this function pulls - # ready items from the queue, allowing the background thread to continue execution. - - fetch_thread = threading.Thread(target=_async_fetch) - fetch_thread.start() - - while True: - next_item = fetch_queue.get(block=True) - if next_item is not sentinel: - yield next_item - fetch_queue.task_done() - if next_item is sentinel: - break - - fetch_queue.join() - fetch_thread.join() - - def _resolve_blocks( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py new file mode 100644 index 0000000000000..ed650766a0c1e --- /dev/null +++ b/python/ray/data/_internal/block_batching/util.py @@ -0,0 +1,85 @@ +import logging +import queue +import threading +from typing import Callable, Iterator, TypeVar + +T = TypeVar("T") +U = TypeVar("U") + +logger = logging.getLogger(__name__) + + +def _make_async_gen( + base_iterator: Iterator[T], + fn: Callable[[Iterator[T]], Iterator[U]], + num_workers: int = 1, +) -> Iterator[U]: + """Returns a new iterator with elements fetched from the base_iterator + in an async fashion using a threadpool. + Each thread in the threadpool will fetch data from the base_iterator in a + thread-safe fashion, and apply the provided computation. triggering the base + iterator's execution. + Args: + base_iterator: The iterator to asynchronously fetch from. + fn: The function to run on the input iterator. + num_workers: The number of threads to use in the threadpool. + Returns: + An iterator with the same elements as the base_iterator. + """ + + def convert_to_threadsafe_iterator(base_iterator: Iterator[T]) -> Iterator[T]: + class ThreadSafeIterator: + def __init__(self, it): + self.lock = threading.Lock() + self.it = it + + def __next__(self): + with self.lock: + return next(self.it) + + def __iter__(self): + return self + + return ThreadSafeIterator(base_iterator) + + thread_safe_generator = convert_to_threadsafe_iterator(base_iterator) + + class Sentinel: + def __init__(self, thread_index: int): + self.thread_index = thread_index + + output_queue = queue.Queue(1) + + def execute_computation(thread_index: int): + try: + for item in fn(thread_safe_generator): + output_queue.put(item, block=True) + output_queue.put(Sentinel(thread_index), block=True) + except Exception as e: + output_queue.put(e, block=True) + + threads = [ + threading.Thread(target=execute_computation, args=(i,), daemon=True) + for i in range(num_workers) + ] + + for thread in threads: + thread.start() + + num_threads_finished = 0 + while True: + next_item = output_queue.get(block=True) + if isinstance(next_item, Exception): + output_queue.task_done() + raise next_item + if isinstance(next_item, Sentinel): + output_queue.task_done() + logger.debug(f"Thread {next_item.thread_index} finished.") + num_threads_finished += 1 + threads[next_item.thread_index].join() + else: + yield next_item + output_queue.task_done() + if num_threads_finished >= num_workers: + output_queue.join() + break diff --git a/python/ray/data/tests/test_block_batching.py b/python/ray/data/tests/block_batching/test_block_batching.py similarity index 79% rename from python/ray/data/tests/test_block_batching.py rename to python/ray/data/tests/block_batching/test_block_batching.py index 3cfdf80eef613..f580a71b24f5b 100644 --- a/python/ray/data/tests/test_block_batching.py +++ b/python/ray/data/tests/block_batching/test_block_batching.py @@ -8,14 +8,13 @@ import pyarrow as pa from ray.data.block import Block -from ray.data._internal.block_batching import ( +from ray.data._internal.block_batching.block_batching import ( BlockPrefetcher, batch_block_refs, batch_blocks, _prefetch_blocks, _blocks_to_batches, _format_batches, - _make_async_gen, ) @@ -123,64 +122,6 @@ def test_format_batches(batch_format): assert isinstance(batch["foo"], np.ndarray) -def test_make_async_gen(): - """Tests that make_async_gen overlaps compute.""" - - num_items = 10 - - def gen(): - for i in range(num_items): - time.sleep(2) - yield i - - def sleep_udf(item): - time.sleep(3) - return item - - iterator = _make_async_gen(gen()) - - start_time = time.time() - outputs = [] - for item in iterator: - outputs.append(sleep_udf(item)) - end_time = time.time() - - assert outputs == list(range(num_items)) - - assert end_time - start_time < num_items * 3 + 3 - - -def test_make_async_gen_buffer_size(): - """Tests that multiple items can be prefetched at a time - with larger buffer size.""" - - num_items = 5 - - def gen(): - for i in range(num_items): - time.sleep(1) - yield i - - def sleep_udf(item): - time.sleep(5) - return item - - iterator = _make_async_gen(gen(), prefetch_buffer_size=4) - - start_time = time.time() - - # Only sleep for first item. - sleep_udf(next(iterator)) - - # All subsequent items should already be prefetched and should be ready. - for _ in iterator: - pass - end_time = time.time() - - # 1 second for first item, 5 seconds for udf, 0.5 seconds buffer - assert end_time - start_time < 6.5 - - # Test for 3 cases # 1. Batch size is less than block size # 2. Batch size is more than block size diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py new file mode 100644 index 0000000000000..47140686de01b --- /dev/null +++ b/python/ray/data/tests/block_batching/test_util.py @@ -0,0 +1,90 @@ +import pytest +import time + +from ray.data._internal.block_batching.util import _make_async_gen + + +def test_make_async_gen_fail(): + """Tests that any errors raised in async threads are propagated to the main + thread.""" + + def gen(base_iterator): + raise ValueError("Fail") + + iterator = _make_async_gen(base_iterator=iter([1]), fn=gen) + + with pytest.raises(ValueError) as e: + for _ in iterator: + pass + + assert e.match("Fail") + + +def test_make_async_gen(): + """Tests that make_async_gen overlaps compute.""" + + num_items = 10 + + def gen(base_iterator): + for i in base_iterator: + time.sleep(2) + yield i + + def sleep_udf(item): + time.sleep(3) + return item + + iterator = _make_async_gen( + base_iterator=iter(range(num_items)), fn=gen, num_workers=1 + ) + + start_time = time.time() + outputs = [] + for item in iterator: + print(item) + outputs.append(sleep_udf(item)) + end_time = time.time() + + assert outputs == list(range(num_items)) + + # Three second buffer. + assert end_time - start_time < num_items * 3 + 3 + + +def test_make_async_gen_multiple_threads(): + """Tests that using multiple threads can overlap compute even more.""" + + num_items = 5 + + def gen(base_iterator): + for i in base_iterator: + time.sleep(4) + yield i + + def sleep_udf(item): + time.sleep(5) + return item + + # All 5 items should be fetched concurrently. + iterator = _make_async_gen( + base_iterator=iter(range(num_items)), fn=gen, num_workers=5 + ) + + start_time = time.time() + + # Only sleep for first item. + sleep_udf(next(iterator)) + + # All subsequent items should already be prefetched and should be ready. + for _ in iterator: + pass + end_time = time.time() + + # 4 second for first item, 5 seconds for udf, 0.5 seconds buffer + assert end_time - start_time < 9.5 + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) From da336f3a9e7ba52d49176fe369e6639aaf14b4cf Mon Sep 17 00:00:00 2001 From: amogkam Date: Tue, 21 Mar 2023 23:21:11 -0700 Subject: [PATCH 02/75] fix Signed-off-by: amogkam --- .../data/tests/block_batching/test_block_batching.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/data/tests/block_batching/test_block_batching.py b/python/ray/data/tests/block_batching/test_block_batching.py index f580a71b24f5b..357eff91cc42f 100644 --- a/python/ray/data/tests/block_batching/test_block_batching.py +++ b/python/ray/data/tests/block_batching/test_block_batching.py @@ -25,9 +25,9 @@ def block_generator(num_rows: int, num_blocks: int): def test_batch_block_refs(): with mock.patch( - "ray.data._internal.block_batching._prefetch_blocks" + "ray.data._internal.block_batching.block_batching._prefetch_blocks" ) as mock_prefetch, mock.patch( - "ray.data._internal.block_batching.batch_blocks" + "ray.data._internal.block_batching.block_batching.batch_blocks" ) as mock_batch_blocks: block_iter = block_generator(2, 2) batch_iter = batch_block_refs(block_iter) @@ -39,9 +39,9 @@ def test_batch_block_refs(): def test_batch_blocks(): with mock.patch( - "ray.data._internal.block_batching._blocks_to_batches" + "ray.data._internal.block_batching.block_batching._blocks_to_batches" ) as mock_batch, mock.patch( - "ray.data._internal.block_batching._format_batches" + "ray.data._internal.block_batching.block_batching._format_batches" ) as mock_format: block_iter = block_generator(2, 2) batch_iter = batch_blocks(block_iter) @@ -136,7 +136,8 @@ def sleep_batch_format(batch_iter, *args, **kwargs): yield batch with mock.patch( - "ray.data._internal.block_batching._format_batches", sleep_batch_format + "ray.data._internal.block_batching.block_batching._format_batches", + sleep_batch_format, ): batch_iter = batch_blocks( batch_size=batch_size, blocks=blocks, prefetch_batches=1 From 98c7918220791adb4108a1bc3d3a7d7e906fb8d5 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 11:39:41 -0700 Subject: [PATCH 03/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/block_batching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 63b3949a02331..d6485f2ca0c96 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -166,7 +166,7 @@ def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: batch_iter = batch_fn_iter(batch_iter) if prefetch_batches > 0: - batch_iter = _make_async_gen(batch_iter, prefetch_buffer_size=prefetch_batches) + batch_iter = _make_async_gen(batch_iter, num_workers=prefetch_batches) for formatted_batch in batch_iter: user_timer = stats.iter_user_s.timer() if stats else nullcontext() From 18d66b648da4945f32dc4b179bebff6ffc28820b Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 11:48:55 -0700 Subject: [PATCH 04/75] newline Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index ed650766a0c1e..e35a28e595c92 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -19,10 +19,12 @@ def _make_async_gen( Each thread in the threadpool will fetch data from the base_iterator in a thread-safe fashion, and apply the provided computation. triggering the base iterator's execution. + Args: base_iterator: The iterator to asynchronously fetch from. fn: The function to run on the input iterator. num_workers: The number of threads to use in the threadpool. + Returns: An iterator with the same elements as the base_iterator. """ From 02d3f7e209489c02900724c037547af84723339d Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 13:03:16 -0700 Subject: [PATCH 05/75] fix Signed-off-by: amogkam --- .../block_batching/block_batching.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 15ed916a36b80..145556e689093 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -36,7 +36,7 @@ def batch_block_refs( prefetch_blocks: int = 0, clear_block_after_read: bool = False, batch_size: Optional[int] = None, - batch_format: Optional[str] = "default", + batch_format: str = "default", drop_last: bool = False, collate_fn: Optional[Callable[[DataBatch], Any]] = None, shuffle_buffer_min_size: Optional[int] = None, @@ -128,7 +128,7 @@ def batch_blocks( *, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, batch_size: Optional[int] = None, - batch_format: Optional[str] = "default", + batch_format: str = "default", drop_last: bool = False, collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None, shuffle_buffer_min_size: Optional[int] = None, @@ -143,30 +143,36 @@ def batch_blocks( This means that this function does not support block prefetching. """ - batch_iter = _format_batches( - _blocks_to_batches( - block_iter=blocks, + def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]: + batch_iter = _format_batches( + _blocks_to_batches( + block_iter=base_iterator, + stats=stats, + batch_size=batch_size, + drop_last=drop_last, + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ensure_copy=ensure_copy, + ), + batch_format=batch_format, stats=stats, - batch_size=batch_size, - drop_last=drop_last, - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ensure_copy=ensure_copy, - ), - batch_format=batch_format, - stats=stats, - ) + ) - if collate_fn is not None: + if collate_fn is not None: - def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: - for batch in iterator: - yield collate_fn(batch) + def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: + for batch in iterator: + yield collate_fn(batch) - batch_iter = batch_fn_iter(batch_iter) + batch_iter = batch_fn_iter(batch_iter) + yield from batch_iter if prefetch_batches > 0: - batch_iter = _make_async_gen(batch_iter, num_workers=prefetch_batches) + batch_iter = _make_async_gen( + blocks, fn=_iterator_fn, num_workers=prefetch_batches + ) + else: + batch_iter = _iterator_fn(blocks) for formatted_batch in batch_iter: user_timer = stats.iter_user_s.timer() if stats else nullcontext() @@ -337,7 +343,7 @@ def get_iter_next_batch_s_timer(): def _format_batches( block_iter: Iterator[Block], - batch_format: Optional[str], + batch_format: str, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[DataBatch]: """Given an iterator of blocks, returns an iterator of formatted batches. From 241d58f6b6b809ca783691e0fa3542e52ebfd218 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 14:30:07 -0700 Subject: [PATCH 06/75] wip Signed-off-by: amogkam --- .../block_batching/block_batching.py | 57 +------------------ .../ray/data/_internal/block_batching/util.py | 45 +++++++++++++++ python/ray/data/_internal/memory_tracing.py | 3 + 3 files changed, 51 insertions(+), 54 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 145556e689093..1c2e0b79a047d 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -4,15 +4,14 @@ from typing import Any, Callable, Iterator, Optional, TypeVar, Union import ray -from ray.actor import ActorHandle -from ray.data._internal.block_batching.util import _make_async_gen +from ray.data._internal.block_batching.interfaces import BlockPrefetcher +from ray.data._internal.block_batching.util import _make_async_gen, WaitBlockPrefetcher, ActorBlockPrefetcher from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.data._internal.memory_tracing import trace_deallocation from ray.data.block import Block, BlockAccessor, DataBatch from ray.data.context import DatasetContext from ray.types import ObjectRef -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy T = TypeVar("T") @@ -26,9 +25,6 @@ def nullcontext(enter_result=None): yield enter_result -PREFETCHER_ACTOR_NAMESPACE = "ray.dataset" - - def batch_block_refs( block_refs: Iterator[ObjectRef[Block]], *, @@ -229,7 +225,7 @@ def _resolve_blocks( def _prefetch_blocks( block_ref_iter: Iterator[ObjectRef[Block]], - prefetcher: "BlockPrefetcher", + prefetcher: BlockPrefetcher, num_blocks_to_prefetch: int, clear_block_after_read: bool = False, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, @@ -359,50 +355,3 @@ def _format_batches( with stats.iter_format_batch_s.timer() if stats else nullcontext(): batch = BlockAccessor.for_block(block).to_batch_format(batch_format) yield batch - - -class BlockPrefetcher: - """Interface for prefetching blocks.""" - - def prefetch_blocks(self, blocks: ObjectRef[Block]): - """Prefetch the provided blocks to this node.""" - raise NotImplementedError - - -class WaitBlockPrefetcher(BlockPrefetcher): - """Block prefetcher using ray.wait.""" - - def prefetch_blocks(self, blocks: ObjectRef[Block]): - ray.wait(blocks, num_returns=1, fetch_local=True) - - -# ray.wait doesn't work as expected, so we have an -# actor-based prefetcher as a work around. See -# https://github.com/ray-project/ray/issues/23983 for details. -class ActorBlockPrefetcher(BlockPrefetcher): - """Block prefetcher using a local actor.""" - - def __init__(self): - self.prefetch_actor = self._get_or_create_actor_prefetcher() - - @staticmethod - def _get_or_create_actor_prefetcher() -> "ActorHandle": - node_id = ray.get_runtime_context().node_id - actor_name = f"dataset-block-prefetcher-{node_id}" - return _BlockPretcher.options( - scheduling_strategy=NodeAffinitySchedulingStrategy(node_id, soft=False), - name=actor_name, - namespace=PREFETCHER_ACTOR_NAMESPACE, - get_if_exists=True, - ).remote() - - def prefetch_blocks(self, blocks: ObjectRef[Block]): - self.prefetch_actor.prefetch.remote(*blocks) - - -@ray.remote(num_cpus=0) -class _BlockPretcher: - """Helper actor that prefetches blocks asynchronously.""" - - def prefetch(self, *blocks) -> None: - pass diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index e35a28e595c92..63dd15e42fc9a 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,6 +3,13 @@ import threading from typing import Callable, Iterator, TypeVar +import ray +from ray.actor import ActorHandle +from ray.types import ObjectRef +from ray.data.block import Block +from ray.data._internal.block_batching.interfaces import BlockPrefetcher +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + T = TypeVar("T") U = TypeVar("U") @@ -85,3 +92,41 @@ def execute_computation(thread_index: int): if num_threads_finished >= num_workers: output_queue.join() break + +PREFETCHER_ACTOR_NAMESPACE = "ray.dataset" +class WaitBlockPrefetcher(BlockPrefetcher): + """Block prefetcher using ray.wait.""" + + def prefetch_blocks(self, blocks: ObjectRef[Block]): + ray.wait(blocks, num_returns=1, fetch_local=True) + + +# ray.wait doesn't work as expected, so we have an +# actor-based prefetcher as a work around. See +# https://github.com/ray-project/ray/issues/23983 for details. +class ActorBlockPrefetcher(BlockPrefetcher): + """Block prefetcher using a local actor.""" + + def __init__(self): + self.prefetch_actor = self._get_or_create_actor_prefetcher() + + @staticmethod + def _get_or_create_actor_prefetcher() -> "ActorHandle": + node_id = ray.get_runtime_context().node_id + actor_name = f"dataset-block-prefetcher-{node_id}" + return _BlockPretcher.options( + scheduling_strategy=NodeAffinitySchedulingStrategy(node_id, soft=False), + name=actor_name, + namespace=PREFETCHER_ACTOR_NAMESPACE, + get_if_exists=True, + ).remote() + + def prefetch_blocks(self, blocks: ObjectRef[Block]): + self.prefetch_actor.prefetch.remote(*blocks) + +@ray.remote(num_cpus=0) +class _BlockPretcher: + """Helper actor that prefetches blocks asynchronously.""" + + def prefetch(self, *blocks) -> None: + pass diff --git a/python/ray/data/_internal/memory_tracing.py b/python/ray/data/_internal/memory_tracing.py index 2cdfb0d4defa5..762204ad1ccc5 100644 --- a/python/ray/data/_internal/memory_tracing.py +++ b/python/ray/data/_internal/memory_tracing.py @@ -99,6 +99,9 @@ def trace_dealloc(self, ref: List[ray.ObjectRef], loc: str, freed: bool): self.cur_mem -= size_bytes self.deallocated[ref] = self.allocated.pop(ref) self.deallocated[ref]["dealloc_loc"] = loc + if ref in self.deallocated: + # This object reference is already deallocated. + pass else: print(f"[mem_tracing] WARNING: allocation of {ref} was not traced!") else: From 3cafc47d9c99910516126c6409f310467c1affba Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 14:30:35 -0700 Subject: [PATCH 07/75] wip Signed-off-by: amogkam --- .../_internal/block_batching/interfaces.py | 73 +++++ .../_internal/block_batching/iter_batches.py | 260 ++++++++++++++++++ .../tests/block_batching/test_interfaces.py | 20 ++ .../tests/block_batching/test_iter_batches.py | 244 ++++++++++++++++ 4 files changed, 597 insertions(+) create mode 100644 python/ray/data/_internal/block_batching/interfaces.py create mode 100644 python/ray/data/_internal/block_batching/iter_batches.py create mode 100644 python/ray/data/tests/block_batching/test_interfaces.py create mode 100644 python/ray/data/tests/block_batching/test_iter_batches.py diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py new file mode 100644 index 0000000000000..ae77de1011663 --- /dev/null +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from typing import List, Optional + +import ray +from ray.types import ObjectRef +from ray.data.block import Block, DataBatch + +@dataclass +class LogicalBatch: + """A logical "batch" of data. + + This is not a fully created batch, but rather a conceptual batch + consisting of unresolved Block Object references. + + Attributes: + bundle_idx: The global index of this bundle so that downstream operations can + maintain ordering. + block_refs: The list of block object references for this batch. + blocks: The resolved blocks for this batch. This attribute can only be accessed + after calling `.resolve()` + starting_block_idx: The index of the first block where this batch starts. + ending_block_idx: The index of the last block where this batch ends. This can + also be None, meaning the entirety of the last block is included in this + batch. If this value is None, this allows us to eagerly clear the last + block in this batch after reading, since the last block is not included in + any other batches. + num_rows: The number of rows in this batch. This should be equivalent to the + provided batch size, except for the final batch. + """ + + batch_idx: int + block_refs: List[ObjectRef[Block]] + starting_block_idx: int + ending_block_idx: Optional[int] + num_rows: int + + def __post_init__(self): + self._resolved = False + + def resolve(self): + """Resolves the block_refs in this LogicalBatch.""" + if self._resolved: + return + self._resolved = True + self._blocks = ray.get(self.block_refs) + + @property + def blocks(self) -> List[Block]: + if not self._resolved: + raise RuntimeError("The resolved blocks for this logical batch can only be " + "accessed after calling `resolve`.") + return self._blocks + +@dataclass +class Batch: + """A batch of data. + + Attributes: + batch_idx: The global index of this batch so that downstream operations can + maintain ordering. + data: The batch of data. + logical_batch: The logical batch that was used to create this batch. + """ + batch_idx: int + data: DataBatch + logical_batch: LogicalBatch + +class BlockPrefetcher: + """Interface for prefetching blocks.""" + + def prefetch_blocks(self, blocks: ObjectRef[Block]): + """Prefetch the provided blocks to this node.""" + raise NotImplementedError \ No newline at end of file diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py new file mode 100644 index 0000000000000..fae4accf521c6 --- /dev/null +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -0,0 +1,260 @@ +import random +from typing import Any, Callable, Iterator, List, Optional, Tuple + +from ray.types import ObjectRef +from ray.data.block import Block, BlockMetadata, BlockAccessor, DataBatch +from ray.data._internal.block_batching.interfaces import Batch, LogicalBatch, BlockPrefetcher +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.memory_tracing import trace_deallocation + +def _bundle_block_refs_to_logical_batches( + block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], + batch_size: Optional[int], + drop_last: bool = False, +) -> Iterator[LogicalBatch]: + """Given an iterator of block object references, and their corresponding metadata, + bundles the block object references into groups of the provided `batch_size`. + + The output iterator returns an iterator over LogicalBatch objects. + + This function does not do any slicing or creation of actual batch objects. + """ + batch_buffer: List[ObjectRef[Block]] = [] + buffer_size = 0 + starting_index = 0 + + global_index = 0 + num_rows_in_last_block = 0 + + if batch_size is None: + for block_ref, metadata in block_ref_iterator: + yield LogicalBatch(global_index, [block_ref], 0, None, metadata.num_rows) + global_index += 1 + else: + while True: + if buffer_size < batch_size: + # Pull next block from iterator if current buffer is not enough to fill + # a batch. + try: + block_ref, metadata = next(block_ref_iterator) + except StopIteration: + break + batch_buffer.append(block_ref) + buffer_size += metadata.num_rows + num_rows_in_last_block = metadata.num_rows + + if buffer_size == batch_size: + # If equal to batch size, then yield the full buffer. + yield LogicalBatch( + global_index, batch_buffer, starting_index, None, buffer_size + ) + batch_buffer = [] + buffer_size = 0 + starting_index = 0 + num_rows_in_last_block = 0 + global_index += 1 + + if buffer_size > batch_size: + # If current buffer is greater than batch size, then yield part of the + # buffer, and carryover the remainder to the next batch. + num_rows_to_leave_behind = buffer_size - batch_size + ending_index = num_rows_in_last_block - num_rows_to_leave_behind + assert ending_index > 0, ending_index + yield LogicalBatch( + global_index, batch_buffer, starting_index, ending_index, batch_size + ) + global_index += 1 + # Carryover to next batch. + batch_buffer = [batch_buffer[-1]] + buffer_size = num_rows_to_leave_behind + starting_index = ending_index + + # Yield any leftover batches if necessary. + if buffer_size > 0 and not drop_last: + assert buffer_size < batch_size + yield LogicalBatch( + global_index, batch_buffer, starting_index, None, buffer_size + ) + global_index += 1 + +def _local_shuffle_logical_batches( + logical_batch_iterator: Iterator[LogicalBatch], + shuffle_buffer_min_size: int, + shuffle_seed: Optional[int] = None, +) -> Iterator[LogicalBatch]: + """Shuffles the logical batch iterator using a buffer of the provided size.""" + + if shuffle_seed is not None: + random.seed(shuffle_seed) + + shuffle_buffer: List[LogicalBatch] = [] + shuffle_buffer_size = 0 + global_counter = 0 + + for logical_batch in logical_batch_iterator: + shuffle_buffer.append(logical_batch) + shuffle_buffer_size += logical_batch.num_rows + + while shuffle_buffer_size >= shuffle_buffer_min_size: + output_batch = shuffle_buffer.pop( + random.randint(0, len(shuffle_buffer) - 1) + ) + output_batch.batch_idx = global_counter + yield output_batch + shuffle_buffer_size -= output_batch.num_rows + global_counter += 1 + + # Yield any leftover. + while len(shuffle_buffer) > 0: + output_batch = shuffle_buffer.pop(random.randint(0, len(shuffle_buffer) - 1)) + output_batch.batch_idx = global_counter + yield output_batch + global_counter += 1 + +def _prefetch_batches_locally( + logical_batch_iter: Iterator[LogicalBatch], + prefetcher: BlockPrefetcher, + num_batches_to_prefetch: int, +) -> Iterator[LogicalBatch]: + """Given an iterator of logical batches, returns an iterator over the same logical + batches, while prefetching `num_batches_to_prefetch` batches in advance. + + Args: + logical_batch_iter: An iterator over logical batches. + prefetcher: The prefetcher to use. + num_batches_to_prefetch: The number of batches to prefetch ahead of the + current batch during the scan. + """ + + def get_next_batches() -> Iterator[List[LogicalBatch]]: + """Return lists of logical batches corresponding to `num_batches_to_prefetch`""" + next_batches = [] + while True: + try: + next_batches.append(next(logical_batch_iter)) + if len(next_batches) == num_batches_to_prefetch: + yield next_batches + next_batches = [] + except StopIteration: + break + + if len(next_batches) > 0: + yield next_batches + + # Fetch the initial set of batches. + batch_iterator = get_next_batches() + try: + batches = next(batch_iterator) + except StopIteration: + return + + block_refs = [block_ref for batch in batches for block_ref in batch.block_refs] + prefetcher.prefetch_blocks(block_refs) + + for next_batches in batch_iterator: + # Prefetch the next batches. + block_refs = [ + block_ref for batch in next_batches for block_ref in batch.block_refs + ] + prefetcher.prefetch_blocks(block_refs) + + for batch in batches: + yield batch + + batches = next_batches + + # Yield the final set of batches. + for batch in batches: + yield batch + +def _construct_batch_from_logical_batch( + resolved_logical_batch_iter: Iterator[LogicalBatch], + ensure_copy: bool = False, +) -> Iterator[Tuple[int, Block]]: + """Given an iterator over logical batches, returns an iterator over actual + constructed batches. + + Args: + resolved_logical_batch_iter: An iterator over resolved logical batches. + stats: Dataset stats object used to store block batching time. + ensure_copy: Whether batches are always copied from the underlying base + blocks (not zero-copy views). + + Returns: + An iterator over batch index and batches of the given size. + """ + + for logical_batch in resolved_logical_batch_iter: + output = DelegatingBlockBuilder() + slice_indices = [[0, None] for _ in range(len(logical_batch.blocks))] + if logical_batch.starting_block_idx > 0: + slice_indices[0][0] = logical_batch.starting_block_idx + if logical_batch.ending_block_idx is not None: + slice_indices[-1][1] = logical_batch.ending_block_idx + + for i, block in enumerate(logical_batch.blocks): + accessor = BlockAccessor.for_block(block) + slice_index = slice_indices[i] + output.add_block( + accessor.slice( + slice_index[0], + slice_index[1] + if slice_index[1] is not None + else accessor.num_rows(), + copy=False, + ) + ) + + batch = output.build() + assert len(batch) == logical_batch.num_rows, ( + len(batch), + logical_batch.num_rows, + ) + if ensure_copy: + # Need to ensure that the batch is a fresh copy. + batch = BlockAccessor.for_block(batch) + batch = batch.slice(0, batch.num_rows(), copy=True) + + yield Batch(logical_batch.batch_idx, batch, logical_batch) + +def _format_batches( + block_iter: Iterator[Batch], + batch_format: str, +) -> Iterator[Batch]: + """Given an iterator of blocks, returns an iterator of formatted batches. + + Args: + block_iter: An iterator over blocks. + batch_format: The batch format to use. + stats: An optional stats object to record formatting times. + + Returns: + An iterator over batch index and the formatted batch. + """ + for batch in block_iter: + formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format(batch_format) + batch.data = formatted_batch + yield batch + +def _collate( + batch_iter: Iterator[Batch], + collate_fn: Optional[Callable[[DataBatch], Any]], +) -> Iterator[Tuple[int, Any]]: + """Returns an iterator with the provided collate_fn applied to items of the batch + iterator. + + Args: + batch_iter: An iterator over formatted batches. + stats: An optional stats object to record collation time. + """ + for batch in batch_iter: + batch.data = collate_fn(batch.data) + yield batch + +def _eagerly_free_blocks(batch_iter: Iterator[Batch]): + """Eagerly free the block references in the batch iterator.""" + + for batch in batch_iter: + block_refs = batch.logical_batch.block_refs + for block_ref in block_refs: + trace_deallocation(block_ref, ) diff --git a/python/ray/data/tests/block_batching/test_interfaces.py b/python/ray/data/tests/block_batching/test_interfaces.py new file mode 100644 index 0000000000000..1f785890fc836 --- /dev/null +++ b/python/ray/data/tests/block_batching/test_interfaces.py @@ -0,0 +1,20 @@ +import pytest + +import ray +from ray.data._internal.block_batching.interfaces import LogicalBatch + +def test_logical_batch_resolves_blocks(ray_start_regular_shared): + block_refs = [ray.put(1), ray.put(2)] + logical_batch = LogicalBatch(batch_idx=0, block_refs=block_refs, starting_block_idx=0, ending_block_idx=None, num_rows=2) + + # Blocks should not be accessible before calling resolve(). + with pytest.raises(RuntimeError): + logical_batch.blocks + + logical_batch.resolve() + assert logical_batch.blocks == [1, 2] + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) \ No newline at end of file diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py new file mode 100644 index 0000000000000..b4bdc5b741217 --- /dev/null +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -0,0 +1,244 @@ +from copy import copy +import pytest +from typing import Iterator, List, Tuple + +import numpy as np +import pandas as pd +import pyarrow as pa + +from ray.data.block import Block, BlockMetadata +from ray.data._internal.block_batching.interfaces import Batch, LogicalBatch, BlockPrefetcher +from ray.data._internal.block_batching.iter_batches import _bundle_block_refs_to_logical_batches, _local_shuffle_logical_batches, _prefetch_batches_locally, _construct_batch_from_logical_batch, _format_batches, _collate + + +def block_generator( + num_rows: int, num_blocks: int +) -> Iterator[Tuple[Block, BlockMetadata]]: + for i in range(num_blocks): + yield pa.table({"foo": [i] * num_rows}), BlockMetadata( + num_rows=num_rows, + size_bytes=0, + schema=None, + input_files=[], + exec_stats=None, + ) + +def logical_batch_generator( + num_rows: int, num_blocks: int, batch_size: int = None +) -> Iterator[LogicalBatch]: + logical_batch_iter = _bundle_block_refs_to_logical_batches( + block_generator(num_rows=num_rows, num_blocks=num_blocks), batch_size=batch_size + ) + + # Force resolve to True for testing purposes. + for logical_batch in logical_batch_iter: + logical_batch._resolved = True + logical_batch._blocks = logical_batch.block_refs + yield logical_batch + + +def test_bundle_block_refs_to_logical_batches(): + # Case 1: `batch_size` is None. + num_blocks = 4 + num_rows_per_block = 2 + batch_size = None + block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) + block_refs = list(block_iter) + logical_batch_iter = _bundle_block_refs_to_logical_batches( + iter(block_refs), batch_size=batch_size + ) + logical_batches = list(logical_batch_iter) + assert logical_batches == [ + LogicalBatch(0, [block_refs[0][0]], 0, None, num_rows_per_block), + LogicalBatch(1, [block_refs[1][0]], 0, None, num_rows_per_block), + LogicalBatch(2, [block_refs[2][0]], 0, None, num_rows_per_block), + LogicalBatch(3, [block_refs[3][0]], 0, None, num_rows_per_block), + ] + + # Case 2: Multiple batches in a block (`batch_size` is 1). + num_blocks = 2 + num_rows_per_block = 2 + batch_size = 1 + block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) + block_refs = list(block_iter) + logical_batch_iter = _bundle_block_refs_to_logical_batches( + iter(block_refs), batch_size=batch_size + ) + logical_batches = list(logical_batch_iter) + assert logical_batches == [ + LogicalBatch(0, [block_refs[0][0]], 0, 1, batch_size), + LogicalBatch(1, [block_refs[0][0]], 1, None, batch_size), + LogicalBatch(2, [block_refs[1][0]], 0, 1, batch_size), + LogicalBatch(3, [block_refs[1][0]], 1, None, batch_size), + ] + + # Case 3: Multiple blocks in a batch (`batch_size` is 2) + num_blocks = 4 + num_rows_per_block = 1 + batch_size = 2 + block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) + block_refs = list(block_iter) + logical_batch_iter = _bundle_block_refs_to_logical_batches( + iter(block_refs), batch_size=batch_size + ) + logical_batches = list(logical_batch_iter) + assert logical_batches == [ + LogicalBatch(0, [block_refs[0][0], block_refs[1][0]], 0, None, batch_size), + LogicalBatch(1, [block_refs[2][0], block_refs[3][0]], 0, None, batch_size), + ] + + # Case 4: Batches overlap across multiple blocks unevenly + num_blocks = 4 + num_rows_per_block = 2 + batch_size = 3 + block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) + block_refs = list(block_iter) + logical_batch_iter = _bundle_block_refs_to_logical_batches( + iter(block_refs), batch_size=batch_size + ) + logical_batches = list(logical_batch_iter) + assert logical_batches == [ + LogicalBatch(0, [block_refs[0][0], block_refs[1][0]], 0, 1, batch_size), + LogicalBatch(1, [block_refs[1][0], block_refs[2][0]], 1, None, batch_size), + LogicalBatch(2, [block_refs[3][0]], 0, None, 2), # Leftover block. + ] + + # Case 5: Batches overlap across multiple blocks unevenly, dropping the last + # incomplete batch. + num_blocks = 4 + num_rows_per_block = 2 + batch_size = 3 + block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) + block_refs = list(block_iter) + logical_batch_iter = _bundle_block_refs_to_logical_batches( + iter(block_refs), batch_size=batch_size, drop_last=True + ) + logical_batches = list(logical_batch_iter) + assert logical_batches == [ + LogicalBatch(0, [block_refs[0][0], block_refs[1][0]], 0, 1, batch_size), + LogicalBatch(1, [block_refs[1][0], block_refs[2][0]], 1, None, batch_size), + ] + +def test_local_shuffle_logical_batches(): + # Case 1: Shuffle buffer min size is smaller than a batch. + # In this case, there is effectively no shuffling since the buffer + # never contains more than 1 batch. + shuffle_seed = 42 + num_blocks = 4 + num_rows_per_block = 2 + shuffle_buffer_min_size = 1 + logical_batches = list(logical_batch_generator(num_rows_per_block, num_blocks)) + shuffled_batches = list( + _local_shuffle_logical_batches( + iter(logical_batches), + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + ) + assert shuffled_batches == logical_batches + + # Case 2: Shuffle buffer min size is greater than a batch. + shuffle_seed = 42 + num_blocks = 4 + num_rows_per_block = 1 + shuffle_buffer_min_size = 2 + logical_batches = list(logical_batch_generator(num_rows_per_block, num_blocks)) + shuffled_batches = list( + _local_shuffle_logical_batches( + iter(logical_batches), + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + ) + + expected_output_ordering = [0, 1, 3, 2] + expected_output = [copy(logical_batches[i]) for i in expected_output_ordering] + for i in range(len(expected_output)): + expected_output[i].batch_idx = i + + assert shuffled_batches == expected_output + +@pytest.mark.parametrize("num_batches_to_prefetch", [1, 2]) +def test_prefetch_batches_locally(num_batches_to_prefetch): + class DummyPrefetcher(BlockPrefetcher): + def __init__(self): + self.windows = [] + + def prefetch_blocks(self, blocks: List[Block]): + self.windows.append(blocks) + + num_blocks = 10 + prefetcher = DummyPrefetcher() + logical_batches = list(logical_batch_generator(1, num_blocks)) + prefetch_batch_iter = _prefetch_batches_locally( + iter(logical_batches), + prefetcher=prefetcher, + num_batches_to_prefetch=num_batches_to_prefetch, + ) + + # Test that we are actually prefetching. + # We should prefetch a new set of batches after the current set + # finishes. + sets_prefetched = 1 + output_batches = [] + for i, batch in enumerate(prefetch_batch_iter): + if i % num_batches_to_prefetch == 0: + # If all the batches are already prefetched, then skip the check. + if not sets_prefetched * num_batches_to_prefetch >= len(logical_batches): + assert len(prefetcher.windows) == sets_prefetched + 1 + sets_prefetched = len(prefetcher.windows) + output_batches.append(batch) + + windows = prefetcher.windows + assert all(len(window) == num_batches_to_prefetch for window in windows) + + # Check that the output iterator is the same as the input iterator. + assert output_batches == logical_batches + +@pytest.mark.parametrize("block_size", [1, 10]) +def test_construct_batch_from_logical_batch(block_size): + num_blocks = 5 + batch_size = 3 + logical_batches = list( + logical_batch_generator(block_size, num_blocks, batch_size=batch_size) + ) + + created_batches = list( + _construct_batch_from_logical_batch(iter(logical_batches)) + ) + + for i, batch in enumerate(created_batches): + assert i == batch.batch_idx + assert len(batch.data) == logical_batches[i].num_rows + +@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) +def test_format_batches(batch_format): + batches = [Batch(i, data[0], None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2))] + batch_iter = _format_batches(batches, batch_format=batch_format) + + for i, batch in enumerate(batch_iter): + assert batch.batch_idx == i + if batch_format == "pandas": + assert isinstance(batch.data, pd.DataFrame) + elif batch_format == "arrow": + assert isinstance(batch.data, pa.Table) + elif batch_format == "numpy": + assert isinstance(batch.data, dict) + assert isinstance(batch.data["foo"], np.ndarray) + + +def test_collate(): + def collate_fn(batch): + return pa.table({"bar": [1] * 2}) + + batches = [Batch(i, data[0], None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2))] + batch_iter = _collate(batches, collate_fn=collate_fn) + + for i, batch in enumerate(batch_iter): + assert batch.batch_idx == i + assert batch.data == pa.table({"bar": [1] * 2}) + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) \ No newline at end of file From da9d7020b19393ca6d5792293c3e8429fcc37c92 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 14:34:50 -0700 Subject: [PATCH 08/75] update Signed-off-by: amogkam --- .../block_batching/block_batching.py | 6 ++- .../_internal/block_batching/interfaces.py | 14 +++++-- .../_internal/block_batching/iter_batches.py | 25 ++++++------ .../ray/data/_internal/block_batching/util.py | 4 ++ .../tests/block_batching/test_interfaces.py | 12 +++++- .../tests/block_batching/test_iter_batches.py | 39 ++++++++++++++----- 6 files changed, 73 insertions(+), 27 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 1c2e0b79a047d..2c4b380eb6a1d 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -5,7 +5,11 @@ import ray from ray.data._internal.block_batching.interfaces import BlockPrefetcher -from ray.data._internal.block_batching.util import _make_async_gen, WaitBlockPrefetcher, ActorBlockPrefetcher +from ray.data._internal.block_batching.util import ( + _make_async_gen, + WaitBlockPrefetcher, + ActorBlockPrefetcher, +) from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.data._internal.memory_tracing import trace_deallocation diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index ae77de1011663..5165b272b4b40 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -5,6 +5,7 @@ from ray.types import ObjectRef from ray.data.block import Block, DataBatch + @dataclass class LogicalBatch: """A logical "batch" of data. @@ -47,27 +48,32 @@ def resolve(self): @property def blocks(self) -> List[Block]: if not self._resolved: - raise RuntimeError("The resolved blocks for this logical batch can only be " - "accessed after calling `resolve`.") + raise RuntimeError( + "The resolved blocks for this logical batch can only be " + "accessed after calling `resolve`." + ) return self._blocks + @dataclass class Batch: """A batch of data. - + Attributes: batch_idx: The global index of this batch so that downstream operations can maintain ordering. data: The batch of data. logical_batch: The logical batch that was used to create this batch. """ + batch_idx: int data: DataBatch logical_batch: LogicalBatch + class BlockPrefetcher: """Interface for prefetching blocks.""" def prefetch_blocks(self, blocks: ObjectRef[Block]): """Prefetch the provided blocks to this node.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index fae4accf521c6..30a25ac65219e 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -3,9 +3,13 @@ from ray.types import ObjectRef from ray.data.block import Block, BlockMetadata, BlockAccessor, DataBatch -from ray.data._internal.block_batching.interfaces import Batch, LogicalBatch, BlockPrefetcher +from ray.data._internal.block_batching.interfaces import ( + Batch, + LogicalBatch, + BlockPrefetcher, +) from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder -from ray.data._internal.memory_tracing import trace_deallocation + def _bundle_block_refs_to_logical_batches( block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], @@ -77,6 +81,7 @@ def _bundle_block_refs_to_logical_batches( ) global_index += 1 + def _local_shuffle_logical_batches( logical_batch_iterator: Iterator[LogicalBatch], shuffle_buffer_min_size: int, @@ -111,6 +116,7 @@ def _local_shuffle_logical_batches( yield output_batch global_counter += 1 + def _prefetch_batches_locally( logical_batch_iter: Iterator[LogicalBatch], prefetcher: BlockPrefetcher, @@ -167,6 +173,7 @@ def get_next_batches() -> Iterator[List[LogicalBatch]]: for batch in batches: yield batch + def _construct_batch_from_logical_batch( resolved_logical_batch_iter: Iterator[LogicalBatch], ensure_copy: bool = False, @@ -217,6 +224,7 @@ def _construct_batch_from_logical_batch( yield Batch(logical_batch.batch_idx, batch, logical_batch) + def _format_batches( block_iter: Iterator[Batch], batch_format: str, @@ -232,10 +240,13 @@ def _format_batches( An iterator over batch index and the formatted batch. """ for batch in block_iter: - formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format(batch_format) + formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) batch.data = formatted_batch yield batch + def _collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], @@ -250,11 +261,3 @@ def _collate( for batch in batch_iter: batch.data = collate_fn(batch.data) yield batch - -def _eagerly_free_blocks(batch_iter: Iterator[Batch]): - """Eagerly free the block references in the batch iterator.""" - - for batch in batch_iter: - block_refs = batch.logical_batch.block_refs - for block_ref in block_refs: - trace_deallocation(block_ref, ) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 63dd15e42fc9a..b9e08905ff2ee 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -93,7 +93,10 @@ def execute_computation(thread_index: int): output_queue.join() break + PREFETCHER_ACTOR_NAMESPACE = "ray.dataset" + + class WaitBlockPrefetcher(BlockPrefetcher): """Block prefetcher using ray.wait.""" @@ -124,6 +127,7 @@ def _get_or_create_actor_prefetcher() -> "ActorHandle": def prefetch_blocks(self, blocks: ObjectRef[Block]): self.prefetch_actor.prefetch.remote(*blocks) + @ray.remote(num_cpus=0) class _BlockPretcher: """Helper actor that prefetches blocks asynchronously.""" diff --git a/python/ray/data/tests/block_batching/test_interfaces.py b/python/ray/data/tests/block_batching/test_interfaces.py index 1f785890fc836..0f061596a4ea7 100644 --- a/python/ray/data/tests/block_batching/test_interfaces.py +++ b/python/ray/data/tests/block_batching/test_interfaces.py @@ -3,9 +3,16 @@ import ray from ray.data._internal.block_batching.interfaces import LogicalBatch + def test_logical_batch_resolves_blocks(ray_start_regular_shared): block_refs = [ray.put(1), ray.put(2)] - logical_batch = LogicalBatch(batch_idx=0, block_refs=block_refs, starting_block_idx=0, ending_block_idx=None, num_rows=2) + logical_batch = LogicalBatch( + batch_idx=0, + block_refs=block_refs, + starting_block_idx=0, + ending_block_idx=None, + num_rows=2, + ) # Blocks should not be accessible before calling resolve(). with pytest.raises(RuntimeError): @@ -14,7 +21,8 @@ def test_logical_batch_resolves_blocks(ray_start_regular_shared): logical_batch.resolve() assert logical_batch.blocks == [1, 2] + if __name__ == "__main__": import sys - sys.exit(pytest.main(["-v", __file__])) \ No newline at end of file + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b4bdc5b741217..65ef6f167a21a 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -7,8 +7,19 @@ import pyarrow as pa from ray.data.block import Block, BlockMetadata -from ray.data._internal.block_batching.interfaces import Batch, LogicalBatch, BlockPrefetcher -from ray.data._internal.block_batching.iter_batches import _bundle_block_refs_to_logical_batches, _local_shuffle_logical_batches, _prefetch_batches_locally, _construct_batch_from_logical_batch, _format_batches, _collate +from ray.data._internal.block_batching.interfaces import ( + Batch, + LogicalBatch, + BlockPrefetcher, +) +from ray.data._internal.block_batching.iter_batches import ( + _bundle_block_refs_to_logical_batches, + _local_shuffle_logical_batches, + _prefetch_batches_locally, + _construct_batch_from_logical_batch, + _format_batches, + _collate, +) def block_generator( @@ -23,6 +34,7 @@ def block_generator( exec_stats=None, ) + def logical_batch_generator( num_rows: int, num_blocks: int, batch_size: int = None ) -> Iterator[LogicalBatch]: @@ -119,6 +131,7 @@ def test_bundle_block_refs_to_logical_batches(): LogicalBatch(1, [block_refs[1][0], block_refs[2][0]], 1, None, batch_size), ] + def test_local_shuffle_logical_batches(): # Case 1: Shuffle buffer min size is smaller than a batch. # In this case, there is effectively no shuffling since the buffer @@ -158,6 +171,7 @@ def test_local_shuffle_logical_batches(): assert shuffled_batches == expected_output + @pytest.mark.parametrize("num_batches_to_prefetch", [1, 2]) def test_prefetch_batches_locally(num_batches_to_prefetch): class DummyPrefetcher(BlockPrefetcher): @@ -195,6 +209,7 @@ def prefetch_blocks(self, blocks: List[Block]): # Check that the output iterator is the same as the input iterator. assert output_batches == logical_batches + @pytest.mark.parametrize("block_size", [1, 10]) def test_construct_batch_from_logical_batch(block_size): num_blocks = 5 @@ -202,18 +217,20 @@ def test_construct_batch_from_logical_batch(block_size): logical_batches = list( logical_batch_generator(block_size, num_blocks, batch_size=batch_size) ) - - created_batches = list( - _construct_batch_from_logical_batch(iter(logical_batches)) - ) + + created_batches = list(_construct_batch_from_logical_batch(iter(logical_batches))) for i, batch in enumerate(created_batches): assert i == batch.batch_idx assert len(batch.data) == logical_batches[i].num_rows + @pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) def test_format_batches(batch_format): - batches = [Batch(i, data[0], None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2))] + batches = [ + Batch(i, data[0], None) + for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) + ] batch_iter = _format_batches(batches, batch_format=batch_format) for i, batch in enumerate(batch_iter): @@ -231,14 +248,18 @@ def test_collate(): def collate_fn(batch): return pa.table({"bar": [1] * 2}) - batches = [Batch(i, data[0], None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2))] + batches = [ + Batch(i, data[0], None) + for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) + ] batch_iter = _collate(batches, collate_fn=collate_fn) for i, batch in enumerate(batch_iter): assert batch.batch_idx == i assert batch.data == pa.table({"bar": [1] * 2}) + if __name__ == "__main__": import sys - sys.exit(pytest.main(["-v", __file__])) \ No newline at end of file + sys.exit(pytest.main(["-v", __file__])) From 5a07fd230e18627e4e955c114124661edb363a0e Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 15:15:07 -0700 Subject: [PATCH 09/75] more Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 39 +++++++++++++++++++ .../tests/block_batching/test_iter_batches.py | 27 +++++++++++++ 2 files changed, 66 insertions(+) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 30a25ac65219e..6cddca8579e64 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -1,3 +1,4 @@ +import heapq import random from typing import Any, Callable, Iterator, List, Optional, Tuple @@ -9,6 +10,7 @@ BlockPrefetcher, ) from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.memory_tracing import trace_deallocation def _bundle_block_refs_to_logical_batches( @@ -261,3 +263,40 @@ def _collate( for batch in batch_iter: batch.data = collate_fn(batch.data) yield batch + + +def _trace_deallocation( + batch_iter: Iterator[Batch], eager_free: bool +) -> Iterator[Batch]: + """Trace deallocation of the underlying block references for each batch. + + Args: + batch_iter: An iterator over batches. + eager_free: Whether to eagerly free the object reference from the object store. + """ + for batch in batch_iter: + block_refs = batch.logical_batch.block_refs + for block_ref in block_refs: + trace_deallocation(block_ref, loc="iter_batches", free=eager_free) + yield batch + + +def _restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: + """Restores the original order of the provided `batch_iter` + + This function will yield items from `base_iterator` in the correct order based on + each batch's batch_idx. All indexes are expected to be unique. + + `batch_iter` is expected to not have any missing indexes. All indexes from 0 to len + (base_iterator) must be present. + """ + next_index_required = 0 + buffer: List[Batch] = [] + for batch in batch_iter: + heapq.heappush(buffer, (batch.batch_idx, batch)) + if buffer[0][0] == next_index_required: + yield heapq.heappop(buffer)[1] + next_index_required += 1 + + while len(buffer) > 0: + yield heapq.heappop(buffer)[1] diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 65ef6f167a21a..a9e2485742d8e 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,11 +1,13 @@ from copy import copy import pytest from typing import Iterator, List, Tuple +from unittest.mock import patch import numpy as np import pandas as pd import pyarrow as pa +import ray from ray.data.block import Block, BlockMetadata from ray.data._internal.block_batching.interfaces import ( Batch, @@ -19,6 +21,8 @@ _construct_batch_from_logical_batch, _format_batches, _collate, + _trace_deallocation, + _restore_from_original_order, ) @@ -259,6 +263,29 @@ def collate_fn(batch): assert batch.data == pa.table({"bar": [1] * 2}) +@patch.object(ray.data._internal.block_batching.iter_batches, "trace_deallocation") +@pytest.mark.parametrize("eager_free", [True, False]) +def test_trace_deallocation(mock, eager_free): + batches = [Batch(0, 0, LogicalBatch(0, [0], 0, None, 1))] + batch_iter = _trace_deallocation(iter(batches), eager_free=eager_free) + # Test that the underlying batch is not modified. + assert next(batch_iter) == batches[0] + mock.assert_called_once_with(0, loc="iter_batches", free=eager_free) + + +def test_restore_from_original_order(): + base_iterator = [ + Batch(1, None, None), + Batch(0, None, None), + Batch(3, None, None), + Batch(2, None, None), + ] + + ordered = list(_restore_from_original_order(iter(base_iterator))) + idx = [batch.batch_idx for batch in ordered] + assert idx == [0, 1, 2, 3] + + if __name__ == "__main__": import sys From 1beee62a174ff439d681879c8e5dff4fc0b285c8 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 15:39:55 -0700 Subject: [PATCH 10/75] wip Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 154 ++++++++++++++++++ python/ray/data/_internal/stats.py | 5 +- 2 files changed, 157 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 6cddca8579e64..03de9ab942d89 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -11,6 +11,160 @@ ) from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.memory_tracing import trace_deallocation +from ray.data._internal.stats import DatasetStats + + +def iter_batches( + block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], + *, + stats: DatasetStats = None, + clear_block_after_read: bool = False, + batch_size: Optional[int] = None, + batch_format: str = "default", + drop_last: bool = False, + collate_fn: Optional[Callable[[DataBatch], Any]] = None, + shuffle_buffer_min_size: Optional[int] = None, + shuffle_seed: Optional[int] = None, + ensure_copy: bool = False, + prefetch_batches: int = 0, +) -> Iterator[DataBatch]: + """Create formatted batches of data from an iterator of block object references and + corresponding metadata. + + This takes a block iterator and creates batch_size batches, slicing, + unioning, shuffling, prefetching, and formatting blocks as needed. + + This is used by both Dataset.iter_batches() and DatasetPipeline.iter_batches() + + The algorithm is as follows: + + In a single async thread, do the following: + 1. Construct logical batches. This creates groupings of the block object references + based on the corresponding metadata.num_rows. The blocks are not resolved or sliced. + 2. If specified, locally shuffle the logical batches. + 3. Trigger local prefetching of the logical batches. + 4. Then, in a threadpool consisting of `prefetch_batches` threads: + 1. Resolve (i.e. call `ray.get()`) on the underlying block references for each + logical batch. + 2. Perform the necessary batch slicing to construct full batches. + 3. Format the batches to the provided batch format. + 4. Apply the collate function + 5. Fetch outputs from the threadpool, maintaining order of the batches. + + Args: + block_refs: An iterator over block object references and their corresponding + metadata. + clear_block_after_read: Whether to clear the block from object store + manually (i.e. without waiting for Python's automatic GC) after it + is read. Doing so will reclaim memory faster and hence reduce the + memory footprint. However, the caller has to ensure the safety, i.e. + the block will never be accessed again. + batch_size: Record batch size, or None to let the system pick. + batch_format: The format in which to return each batch. + Specify "default" to use the current block format (promoting + Arrow to pandas automatically), "pandas" to + select ``pandas.DataFrame`` or "pyarrow" to select + ``pyarrow.Table``. Default is "default". + drop_last: Whether to drop the last batch if it's incomplete. + collate_fn: A function to apply to each data batch before returning it. + shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a + local in-memory shuffle buffer, and this value will serve as the minimum + number of rows that must be in the local in-memory shuffle buffer in order + to yield a batch. + shuffle_seed: The seed to use for the local random shuffle. + ensure_copy: Whether batches are always copied from the underlying base + blocks (not zero-copy views). + prefetch_batches: The number of batches to fetch ahead of the current batch to + process. If set to greater than 0, a separate thread will be used to fetch + the specified amount of formatted batches from blocks. This improves + performance for non-CPU bound UDFs, allowing batch fetching compute and + formatting to be overlapped with the UDF. Defaults to 0 (no prefetching + enabled). + + Returns: + An iterator over record batches. + """ + context = DatasetContext.get_current() + + if ( + prefetch_batches > 0 + and context.actor_prefetcher_enabled + and not ray.util.client.ray.is_connected() + ): + prefetcher = ActorBlockPrefetcher() + else: + prefetcher = WaitBlockPrefetcher() + + def _async_iter_batches(block_refs): + # Step 1: Construct logical batches based on the metadata. + batch_iter = _bundle_block_refs_to_logical_batches( + block_refs, batch_size=batch_size, drop_last=drop_last + ) + + # Step 2: Shuffle the logical batches if applicable. + if shuffle_buffer_min_size is not None: + batch_iter = _local_shuffle_logical_batches( + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + + # Step 3: Prefetch logical batches locally. + if prefetch_batches > 0: + batch_iter = _prefetch_batches_locally( + batch_iter, + prefetcher=prefetcher, + num_batches_to_prefetch=prefetch_batches, + stats=stats, + ) + + def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): + # Step 4.1: Resolve the blocks. + resolved_batch_iter = _resolve_blocks( + logical_batch_iter, + clear_block_after_read=clear_block_after_read, + stats=stats, + ) + + # Step 4.2: Slice the blocks to create the batch. + batch_iter = _construct_batch_from_logical_batch( + resolved_batch_iter, stats=stats, ensure_copy=ensure_copy + ) + + # Step 4.3: Format the batches. + formatted_batch_iter = _format_batches( + batch_iter, batch_format=batch_format, stats=stats + ) + + # Step 4.4: Apply the collate function if applicable. + if collate_fn is not None: + formatted_batch_iter = _collate( + formatted_batch_iter, collate_fn=collate_fn, stats=stats + ) + yield from formatted_batch_iter + + # Step 4: Use a threadpool for resolving blocks, slicing, formatting, and + # collation. + if prefetch_batches > 0: + batch_iter = _make_async_gen( + batch_iter, fn=threadpool_computations, num_workers=prefetch_batches + ) + # Step 5: Make sure to preserve order from threadpool results. + yield from _preserve_order(batch_iter) + else: + # If no batch prefetching is specified, then don't use a threadpool. + batch_iter = threadpool_computations(batch_iter) + # Drop the index since ordering is already preserved as we are not using a + # threadpool. + for idx, batch in batch_iter: + yield batch + + # Run everything in a separate thread to not block the main thread when waiting + # for streaming results. + async_batch_iter = _make_async_gen( + block_refs, fn=_async_iter_batches, num_workers=1 + ) + + yield from async_batch_iter def _bundle_block_refs_to_logical_batches( diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 4a402e0f213c3..b1d2abfe48f21 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -217,10 +217,11 @@ def __init__( self.stats_uuid = stats_uuid # Iteration stats, filled out if the user iterates over the dataset. - self.iter_wait_s: Timer = Timer() self.iter_get_s: Timer = Timer() - self.iter_next_batch_s: Timer = Timer() + self.iter_create_batch_s: Timer = Timer() self.iter_format_batch_s: Timer = Timer() + self.iter_collate_batch_s: Timer = Timer() + self.iter_total_blocked_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_total_s: Timer = Timer() self.extra_metrics = {} From 7cd338d383df6fdb373165eaad6be855a4b97ccd Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 16:18:11 -0700 Subject: [PATCH 11/75] update Signed-off-by: amogkam --- .../data/_internal/block_batching/iter_batches.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 03de9ab942d89..cfd5cb450cfc9 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -54,6 +54,7 @@ def iter_batches( Args: block_refs: An iterator over block object references and their corresponding metadata. + stats: DatasetStats object to record timing and other statistics. clear_block_after_read: Whether to clear the block from object store manually (i.e. without waiting for Python's automatic GC) after it is read. Doing so will reclaim memory faster and hence reduce the @@ -166,6 +167,18 @@ def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): yield from async_batch_iter +def _batch_in_threadpool( + logical_batch_iterator: Iterator[LogicalBatch], + stats: DatasetStats, + clear_block_after_read: bool = False, + batch_format: str = "default", + collate_fn: Optional[Callable[[DataBatch], Any]] = None, + ensure_copy: bool = False, + prefetch_batches: int = 0, +): + """""" + + def _bundle_block_refs_to_logical_batches( block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], From b74da64ca6548f6dfdc8b102a32a6335cb000454 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 16:40:10 -0700 Subject: [PATCH 12/75] wip Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 372 ++++++++++-------- 1 file changed, 208 insertions(+), 164 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index cfd5cb450cfc9..d57d19cce6fc4 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -14,169 +14,214 @@ from ray.data._internal.stats import DatasetStats -def iter_batches( - block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], - *, - stats: DatasetStats = None, - clear_block_after_read: bool = False, - batch_size: Optional[int] = None, - batch_format: str = "default", - drop_last: bool = False, - collate_fn: Optional[Callable[[DataBatch], Any]] = None, - shuffle_buffer_min_size: Optional[int] = None, - shuffle_seed: Optional[int] = None, - ensure_copy: bool = False, - prefetch_batches: int = 0, -) -> Iterator[DataBatch]: - """Create formatted batches of data from an iterator of block object references and - corresponding metadata. - - This takes a block iterator and creates batch_size batches, slicing, - unioning, shuffling, prefetching, and formatting blocks as needed. - - This is used by both Dataset.iter_batches() and DatasetPipeline.iter_batches() - - The algorithm is as follows: - - In a single async thread, do the following: - 1. Construct logical batches. This creates groupings of the block object references - based on the corresponding metadata.num_rows. The blocks are not resolved or sliced. - 2. If specified, locally shuffle the logical batches. - 3. Trigger local prefetching of the logical batches. - 4. Then, in a threadpool consisting of `prefetch_batches` threads: - 1. Resolve (i.e. call `ray.get()`) on the underlying block references for each - logical batch. - 2. Perform the necessary batch slicing to construct full batches. - 3. Format the batches to the provided batch format. - 4. Apply the collate function - 5. Fetch outputs from the threadpool, maintaining order of the batches. - - Args: - block_refs: An iterator over block object references and their corresponding - metadata. - stats: DatasetStats object to record timing and other statistics. - clear_block_after_read: Whether to clear the block from object store - manually (i.e. without waiting for Python's automatic GC) after it - is read. Doing so will reclaim memory faster and hence reduce the - memory footprint. However, the caller has to ensure the safety, i.e. - the block will never be accessed again. - batch_size: Record batch size, or None to let the system pick. - batch_format: The format in which to return each batch. - Specify "default" to use the current block format (promoting - Arrow to pandas automatically), "pandas" to - select ``pandas.DataFrame`` or "pyarrow" to select - ``pyarrow.Table``. Default is "default". - drop_last: Whether to drop the last batch if it's incomplete. - collate_fn: A function to apply to each data batch before returning it. - shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a - local in-memory shuffle buffer, and this value will serve as the minimum - number of rows that must be in the local in-memory shuffle buffer in order - to yield a batch. - shuffle_seed: The seed to use for the local random shuffle. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). - prefetch_batches: The number of batches to fetch ahead of the current batch to - process. If set to greater than 0, a separate thread will be used to fetch - the specified amount of formatted batches from blocks. This improves - performance for non-CPU bound UDFs, allowing batch fetching compute and - formatting to be overlapped with the UDF. Defaults to 0 (no prefetching - enabled). - - Returns: - An iterator over record batches. - """ - context = DatasetContext.get_current() - - if ( - prefetch_batches > 0 - and context.actor_prefetcher_enabled - and not ray.util.client.ray.is_connected() - ): - prefetcher = ActorBlockPrefetcher() - else: - prefetcher = WaitBlockPrefetcher() - - def _async_iter_batches(block_refs): - # Step 1: Construct logical batches based on the metadata. - batch_iter = _bundle_block_refs_to_logical_batches( - block_refs, batch_size=batch_size, drop_last=drop_last - ) - - # Step 2: Shuffle the logical batches if applicable. - if shuffle_buffer_min_size is not None: - batch_iter = _local_shuffle_logical_batches( - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - - # Step 3: Prefetch logical batches locally. - if prefetch_batches > 0: - batch_iter = _prefetch_batches_locally( - batch_iter, - prefetcher=prefetcher, - num_batches_to_prefetch=prefetch_batches, - stats=stats, - ) - - def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): - # Step 4.1: Resolve the blocks. - resolved_batch_iter = _resolve_blocks( - logical_batch_iter, - clear_block_after_read=clear_block_after_read, - stats=stats, - ) - - # Step 4.2: Slice the blocks to create the batch. - batch_iter = _construct_batch_from_logical_batch( - resolved_batch_iter, stats=stats, ensure_copy=ensure_copy - ) - - # Step 4.3: Format the batches. - formatted_batch_iter = _format_batches( - batch_iter, batch_format=batch_format, stats=stats - ) - - # Step 4.4: Apply the collate function if applicable. - if collate_fn is not None: - formatted_batch_iter = _collate( - formatted_batch_iter, collate_fn=collate_fn, stats=stats - ) - yield from formatted_batch_iter - - # Step 4: Use a threadpool for resolving blocks, slicing, formatting, and - # collation. - if prefetch_batches > 0: - batch_iter = _make_async_gen( - batch_iter, fn=threadpool_computations, num_workers=prefetch_batches - ) - # Step 5: Make sure to preserve order from threadpool results. - yield from _preserve_order(batch_iter) - else: - # If no batch prefetching is specified, then don't use a threadpool. - batch_iter = threadpool_computations(batch_iter) - # Drop the index since ordering is already preserved as we are not using a - # threadpool. - for idx, batch in batch_iter: - yield batch - - # Run everything in a separate thread to not block the main thread when waiting - # for streaming results. - async_batch_iter = _make_async_gen( - block_refs, fn=_async_iter_batches, num_workers=1 - ) - - yield from async_batch_iter - -def _batch_in_threadpool( - logical_batch_iterator: Iterator[LogicalBatch], - stats: DatasetStats, - clear_block_after_read: bool = False, - batch_format: str = "default", - collate_fn: Optional[Callable[[DataBatch], Any]] = None, - ensure_copy: bool = False, - prefetch_batches: int = 0, -): - """""" +# def iter_batches( +# block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], +# *, +# stats: DatasetStats = None, +# clear_block_after_read: bool = False, +# batch_size: Optional[int] = None, +# batch_format: str = "default", +# drop_last: bool = False, +# collate_fn: Optional[Callable[[DataBatch], Any]] = None, +# shuffle_buffer_min_size: Optional[int] = None, +# shuffle_seed: Optional[int] = None, +# ensure_copy: bool = False, +# prefetch_batches: int = 0, +# ) -> Iterator[DataBatch]: +# """Create formatted batches of data from an iterator of block object references and +# corresponding metadata. + +# This takes a block iterator and creates batch_size batches, slicing, +# unioning, shuffling, prefetching, and formatting blocks as needed. + +# This is used by both Dataset.iter_batches() and DatasetPipeline.iter_batches() + +# The algorithm is as follows: + +# In a single async thread, do the following: +# 1. Construct logical batches. This creates groupings of the block object references +# based on the corresponding metadata.num_rows. The blocks are not resolved or sliced. +# 2. If specified, locally shuffle the logical batches. +# 3. Trigger local prefetching of the logical batches. +# 4. Then, in a threadpool consisting of `prefetch_batches` threads: +# 1. Resolve (i.e. call `ray.get()`) on the underlying block references for each +# logical batch. +# 2. Perform the necessary batch slicing to construct full batches. +# 3. Format the batches to the provided batch format. +# 4. Apply the collate function +# 5. Fetch outputs from the threadpool, maintaining order of the batches. + +# Args: +# block_refs: An iterator over block object references and their corresponding +# metadata. +# stats: DatasetStats object to record timing and other statistics. +# clear_block_after_read: Whether to clear the block from object store +# manually (i.e. without waiting for Python's automatic GC) after it +# is read. Doing so will reclaim memory faster and hence reduce the +# memory footprint. However, the caller has to ensure the safety, i.e. +# the block will never be accessed again. +# batch_size: Record batch size, or None to let the system pick. +# batch_format: The format in which to return each batch. +# Specify "default" to use the current block format (promoting +# Arrow to pandas automatically), "pandas" to +# select ``pandas.DataFrame`` or "pyarrow" to select +# ``pyarrow.Table``. Default is "default". +# drop_last: Whether to drop the last batch if it's incomplete. +# collate_fn: A function to apply to each data batch before returning it. +# shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a +# local in-memory shuffle buffer, and this value will serve as the minimum +# number of rows that must be in the local in-memory shuffle buffer in order +# to yield a batch. +# shuffle_seed: The seed to use for the local random shuffle. +# ensure_copy: Whether batches are always copied from the underlying base +# blocks (not zero-copy views). +# prefetch_batches: The number of batches to fetch ahead of the current batch to +# process. If set to greater than 0, a separate thread will be used to fetch +# the specified amount of formatted batches from blocks. This improves +# performance for non-CPU bound UDFs, allowing batch fetching compute and +# formatting to be overlapped with the UDF. Defaults to 0 (no prefetching +# enabled). + +# Returns: +# An iterator over record batches. +# """ +# context = DatasetContext.get_current() + +# if ( +# prefetch_batches > 0 +# and context.actor_prefetcher_enabled +# and not ray.util.client.ray.is_connected() +# ): +# prefetcher = ActorBlockPrefetcher() +# else: +# prefetcher = WaitBlockPrefetcher() + +# def _async_iter_batches(block_refs): +# # Step 1: Construct logical batches based on the metadata. +# batch_iter = _bundle_block_refs_to_logical_batches( +# block_refs, batch_size=batch_size, drop_last=drop_last +# ) + +# # Step 2: Shuffle the logical batches if applicable. +# if shuffle_buffer_min_size is not None: +# batch_iter = _local_shuffle_logical_batches( +# shuffle_buffer_min_size=shuffle_buffer_min_size, +# shuffle_seed=shuffle_seed, +# ) + +# # Step 3: Prefetch logical batches locally. +# if prefetch_batches > 0: +# batch_iter = _prefetch_batches_locally( +# batch_iter, +# prefetcher=prefetcher, +# num_batches_to_prefetch=prefetch_batches, +# stats=stats, +# ) + +# def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): +# # Step 4.1: Resolve the blocks. +# resolved_batch_iter = _resolve_blocks( +# logical_batch_iter, +# clear_block_after_read=clear_block_after_read, +# stats=stats, +# ) + +# # Step 4.2: Slice the blocks to create the batch. +# batch_iter = _construct_batch_from_logical_batch( +# resolved_batch_iter, stats=stats, ensure_copy=ensure_copy +# ) + +# # Step 4.3: Format the batches. +# formatted_batch_iter = _format_batches( +# batch_iter, batch_format=batch_format, stats=stats +# ) + +# # Step 4.4: Apply the collate function if applicable. +# if collate_fn is not None: +# formatted_batch_iter = _collate( +# formatted_batch_iter, collate_fn=collate_fn, stats=stats +# ) +# yield from formatted_batch_iter + +# # Step 4: Use a threadpool for resolving blocks, slicing, formatting, and +# # collation. +# if prefetch_batches > 0: +# batch_iter = _make_async_gen( +# batch_iter, fn=threadpool_computations, num_workers=prefetch_batches +# ) +# # Step 5: Make sure to preserve order from threadpool results. +# yield from _preserve_order(batch_iter) +# else: +# # If no batch prefetching is specified, then don't use a threadpool. +# batch_iter = threadpool_computations(batch_iter) +# # Drop the index since ordering is already preserved as we are not using a +# # threadpool. +# for idx, batch in batch_iter: +# yield batch + +# # Run everything in a separate thread to not block the main thread when waiting +# # for streaming results. +# async_batch_iter = _make_async_gen( +# block_refs, fn=_async_iter_batches, num_workers=1 +# ) + +# yield from async_batch_iter + +# def _batch_in_threadpool( +# logical_batch_iterator: Iterator[LogicalBatch], +# stats: DatasetStats, +# clear_block_after_read: bool = False, +# batch_format: str = "default", +# collate_fn: Optional[Callable[[DataBatch], Any]] = None, +# ensure_copy: bool = False, +# prefetch_batches: int = 0, +# ): +# """Executes the batching, formatting, and collation logic in a threadpool. + +# Args: +# logical_batch_iterator: An iterator over logical batches. +# stats: DatasetStats object to record timing and other statistics. +# clear_block_after_read: Whether to clear the block from object store +# manually (i.e. without waiting for Python's automatic GC) after it +# is read. Doing so will reclaim memory faster and hence reduce the +# memory footprint. However, the caller has to ensure the safety, i.e. +# the block will never be accessed again. +# batch_format: The format in which to return each batch. +# Specify "default" to use the current block format (promoting +# Arrow to pandas automatically), "pandas" to +# select ``pandas.DataFrame`` or "pyarrow" to select +# ``pyarrow.Table``. Default is "default". +# collate_fn: A function to apply to each data batch before returning it. +# ensure_copy: Whether batches are always copied from the underlying base +# blocks (not zero-copy views). +# threadpool_size: The number of threads to use in the threadpool. +# """ + +# def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): +# # Step 4.1: Resolve the blocks. +# resolved_batch_iter = +# resolved_batch_iter = _resolve_blocks( +# logical_batch_iter, +# clear_block_after_read=clear_block_after_read, +# stats=stats, +# ) + +# # Step 4.2: Slice the blocks to create the batch. +# batch_iter = _construct_batch_from_logical_batch( +# resolved_batch_iter, stats=stats, ensure_copy=ensure_copy +# ) + +# # Step 4.3: Format the batches. +# formatted_batch_iter = _format_batches( +# batch_iter, batch_format=batch_format, stats=stats +# ) + +# # Step 4.4: Apply the collate function if applicable. +# if collate_fn is not None: +# formatted_batch_iter = _collate( +# formatted_batch_iter, collate_fn=collate_fn, stats=stats +# ) +# yield from formatted_batch_iter @@ -342,7 +387,6 @@ def get_next_batches() -> Iterator[List[LogicalBatch]]: for batch in batches: yield batch - def _construct_batch_from_logical_batch( resolved_logical_batch_iter: Iterator[LogicalBatch], ensure_copy: bool = False, From c23bbd09fbe42d71aa46ce25605ab76f85bde759 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 16:44:31 -0700 Subject: [PATCH 13/75] add Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 7 +++++++ .../ray/data/tests/block_batching/test_iter_batches.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 6cddca8579e64..085d2108083ea 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -176,6 +176,13 @@ def get_next_batches() -> Iterator[List[LogicalBatch]]: yield batch +def _resolve_logical_batch(logical_batch_iter: Iterator[LogicalBatch]): + """Resolves the block references for each logical batch.""" + for logical_batch in logical_batch_iter: + logical_batch.resolve() + yield logical_batch + + def _construct_batch_from_logical_batch( resolved_logical_batch_iter: Iterator[LogicalBatch], ensure_copy: bool = False, diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index a9e2485742d8e..cd6552fa27249 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -18,6 +18,7 @@ _bundle_block_refs_to_logical_batches, _local_shuffle_logical_batches, _prefetch_batches_locally, + _resolve_logical_batch, _construct_batch_from_logical_batch, _format_batches, _collate, @@ -214,6 +215,14 @@ def prefetch_blocks(self, blocks: List[Block]): assert output_batches == logical_batches +@patch.object(ray.data._internal.block_batching.interfaces.LogicalBatch, "resolve") +def test_resolve_logical_batches(mock): + logical_batches = list(logical_batch_generator(1, 1)) + resolved_iter = _resolve_logical_batch(iter(logical_batches)) + assert next(resolved_iter) == logical_batches[0] + mock.assert_called_once() + + @pytest.mark.parametrize("block_size", [1, 10]) def test_construct_batch_from_logical_batch(block_size): num_blocks = 5 From 855f9fef5f9675220469bea5105f28d9b05acc3d Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 17:10:00 -0700 Subject: [PATCH 14/75] update tests Signed-off-by: amogkam --- .../tests/block_batching/test_iter_batches.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index cd6552fa27249..fbe9f52ad57d5 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -8,6 +8,7 @@ import pyarrow as pa import ray +from ray.types import ObjectRef from ray.data.block import Block, BlockMetadata from ray.data._internal.block_batching.interfaces import ( Batch, @@ -29,9 +30,9 @@ def block_generator( num_rows: int, num_blocks: int -) -> Iterator[Tuple[Block, BlockMetadata]]: +) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: for i in range(num_blocks): - yield pa.table({"foo": [i] * num_rows}), BlockMetadata( + yield ray.put(pa.table({"foo": [i] * num_rows})), BlockMetadata( num_rows=num_rows, size_bytes=0, schema=None, @@ -46,15 +47,19 @@ def logical_batch_generator( logical_batch_iter = _bundle_block_refs_to_logical_batches( block_generator(num_rows=num_rows, num_blocks=num_blocks), batch_size=batch_size ) + return logical_batch_iter - # Force resolve to True for testing purposes. + +def resolved_logical_batch_generator( + num_rows: int, num_blocks: int, batch_size: int = None +): + logical_batch_iter = logical_batch_generator(num_rows, num_blocks, batch_size) for logical_batch in logical_batch_iter: - logical_batch._resolved = True - logical_batch._blocks = logical_batch.block_refs + logical_batch.resolve() yield logical_batch -def test_bundle_block_refs_to_logical_batches(): +def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): # Case 1: `batch_size` is None. num_blocks = 4 num_rows_per_block = 2 @@ -137,7 +142,7 @@ def test_bundle_block_refs_to_logical_batches(): ] -def test_local_shuffle_logical_batches(): +def test_local_shuffle_logical_batches(ray_start_regular_shared): # Case 1: Shuffle buffer min size is smaller than a batch. # In this case, there is effectively no shuffling since the buffer # never contains more than 1 batch. @@ -178,7 +183,7 @@ def test_local_shuffle_logical_batches(): @pytest.mark.parametrize("num_batches_to_prefetch", [1, 2]) -def test_prefetch_batches_locally(num_batches_to_prefetch): +def test_prefetch_batches_locally(ray_start_regular_shared, num_batches_to_prefetch): class DummyPrefetcher(BlockPrefetcher): def __init__(self): self.windows = [] @@ -215,20 +220,18 @@ def prefetch_blocks(self, blocks: List[Block]): assert output_batches == logical_batches -@patch.object(ray.data._internal.block_batching.interfaces.LogicalBatch, "resolve") -def test_resolve_logical_batches(mock): +def test_resolve_logical_batches(ray_start_regular_shared): logical_batches = list(logical_batch_generator(1, 1)) resolved_iter = _resolve_logical_batch(iter(logical_batches)) - assert next(resolved_iter) == logical_batches[0] - mock.assert_called_once() + assert next(resolved_iter).blocks == ray.get(logical_batches[0].block_refs) @pytest.mark.parametrize("block_size", [1, 10]) -def test_construct_batch_from_logical_batch(block_size): +def test_construct_batch_from_logical_batch(ray_start_regular_shared, block_size): num_blocks = 5 batch_size = 3 logical_batches = list( - logical_batch_generator(block_size, num_blocks, batch_size=batch_size) + resolved_logical_batch_generator(block_size, num_blocks, batch_size=batch_size) ) created_batches = list(_construct_batch_from_logical_batch(iter(logical_batches))) @@ -239,9 +242,9 @@ def test_construct_batch_from_logical_batch(block_size): @pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) -def test_format_batches(batch_format): +def test_format_batches(ray_start_regular_shared, batch_format): batches = [ - Batch(i, data[0], None) + Batch(i, ray.get(data[0]), None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) ] batch_iter = _format_batches(batches, batch_format=batch_format) @@ -257,12 +260,12 @@ def test_format_batches(batch_format): assert isinstance(batch.data["foo"], np.ndarray) -def test_collate(): +def test_collate(ray_start_regular_shared): def collate_fn(batch): return pa.table({"bar": [1] * 2}) batches = [ - Batch(i, data[0], None) + Batch(i, ray.get(data[0]), None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) ] batch_iter = _collate(batches, collate_fn=collate_fn) From 4b4851fa86607fad4520ce6671cfc43651ea3fdf Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 17:59:21 -0700 Subject: [PATCH 15/75] wip Signed-off-by: amogkam --- python/ray/data/_internal/batcher.py | 61 +++ .../block_batching/block_batching.py | 62 --- .../_internal/block_batching/iter_batches.py | 380 +++++++++--------- .../ray/data/_internal/block_batching/util.py | 8 +- .../data/_internal/dataset_iterator_impl.py | 83 ---- .../_internal/pipelined_dataset_iterator.py | 75 ---- .../ray/data/_internal/planner/map_batches.py | 2 - .../stream_split_dataset_iterator.py | 262 ------------ python/ray/data/context.py | 6 + python/ray/data/dataset.py | 14 +- python/ray/data/dataset_pipeline.py | 2 +- .../tests/block_batching/test_iter_batches.py | 74 ++++ python/ray/data/tests/test_batcher.py | 35 +- 13 files changed, 365 insertions(+), 699 deletions(-) delete mode 100644 python/ray/data/_internal/dataset_iterator_impl.py delete mode 100644 python/ray/data/_internal/pipelined_dataset_iterator.py delete mode 100644 python/ray/data/_internal/stream_split_dataset_iterator.py diff --git a/python/ray/data/_internal/batcher.py b/python/ray/data/_internal/batcher.py index d1358990dfcbd..a5c84cabfec84 100644 --- a/python/ray/data/_internal/batcher.py +++ b/python/ray/data/_internal/batcher.py @@ -319,3 +319,64 @@ def next_batch(self) -> Block: self._batch_head += batch_size # Yield the shuffled batch. return BlockAccessor.for_block(self._shuffle_buffer).take(batch_indices) + + +def _blocks_to_batches( + block_iter: Iterator[Block], + stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, + batch_size: Optional[int] = None, + drop_last: bool = False, + shuffle_buffer_min_size: Optional[int] = None, + shuffle_seed: Optional[int] = None, + ensure_copy: bool = False, +) -> Iterator[Block]: + """Given an iterator over blocks, returns an iterator over blocks + of the appropriate bacth size. + + If the shuffling configurations are specified, then the + output blocks contain shuffled data. + + Args: + block_iter: An iterator over blocks. + stats: Dataset stats object used to store block batching time. + batch_size: Record batch size, or None to let the system pick. + drop_last: Whether to drop the last batch if it's incomplete. + ensure_copy: Whether batches are always copied from the underlying base + blocks (not zero-copy views). + + Returns: + An iterator over blocks of the given size that are potentially shuffled. + """ + if shuffle_buffer_min_size is not None: + batcher = ShufflingBatcher( + batch_size=batch_size, + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + else: + batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) + + def get_iter_next_batch_s_timer(): + return stats.iter_next_batch_s.timer() if stats else nullcontext() + + for block in block_iter: + batcher.add(block) + while batcher.has_batch(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield batch + + # Signal to the batcher that there are no more blocks to add. + batcher.done_adding() + + # Get any leftover batches in ShufflingBatcher. + while batcher.has_batch(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield batch + + # Get any remaining data. + if not drop_last and batcher.has_any(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield batch \ No newline at end of file diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 2c4b380eb6a1d..8554d2ef99b4e 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -279,68 +279,6 @@ def _prefetch_blocks( block_ref, "block_batching._prefetch_blocks", free=eager_free ) - -def _blocks_to_batches( - block_iter: Iterator[Block], - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, - batch_size: Optional[int] = None, - drop_last: bool = False, - shuffle_buffer_min_size: Optional[int] = None, - shuffle_seed: Optional[int] = None, - ensure_copy: bool = False, -) -> Iterator[Block]: - """Given an iterator over blocks, returns an iterator over blocks - of the appropriate bacth size. - - If the shuffling configurations are specified, then the - output blocks contain shuffled data. - - Args: - block_iter: An iterator over blocks. - stats: Dataset stats object used to store block batching time. - batch_size: Record batch size, or None to let the system pick. - drop_last: Whether to drop the last batch if it's incomplete. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). - - Returns: - An iterator over blocks of the given size that are potentially shuffled. - """ - if shuffle_buffer_min_size is not None: - batcher = ShufflingBatcher( - batch_size=batch_size, - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - else: - batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) - - def get_iter_next_batch_s_timer(): - return stats.iter_next_batch_s.timer() if stats else nullcontext() - - for block in block_iter: - batcher.add(block) - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Signal to the batcher that there are no more blocks to add. - batcher.done_adding() - - # Get any leftover batches in ShufflingBatcher. - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Get any remaining data. - if not drop_last and batcher.has_any(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - def _format_batches( block_iter: Iterator[Block], batch_format: str, diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index fa089399e66cd..9c8f5ea53605b 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -6,12 +6,13 @@ import ray from ray.types import ObjectRef from ray.data.block import Block, BlockMetadata, BlockAccessor, DataBatch +from ray.data.context import DatasetContext from ray.data._internal.block_batching.interfaces import ( Batch, LogicalBatch, BlockPrefetcher, ) -from ray.data._internal.block_batching.util import _calculate_ref_hits +from ray.data._internal.block_batching.util import _calculate_ref_hits, _make_async_gen, ActorBlockPrefetcher, WaitBlockPrefetcher from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetStats @@ -26,178 +27,153 @@ def nullcontext(enter_result=None): yield enter_result -# def iter_batches( -# block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], -# *, -# stats: DatasetStats = None, -# clear_block_after_read: bool = False, -# batch_size: Optional[int] = None, -# batch_format: str = "default", -# drop_last: bool = False, -# collate_fn: Optional[Callable[[DataBatch], Any]] = None, -# shuffle_buffer_min_size: Optional[int] = None, -# shuffle_seed: Optional[int] = None, -# ensure_copy: bool = False, -# prefetch_batches: int = 0, -# ) -> Iterator[DataBatch]: -# """Create formatted batches of data from an iterator of block object references and -# corresponding metadata. - -# This takes a block iterator and creates batch_size batches, slicing, -# unioning, shuffling, prefetching, and formatting blocks as needed. - -# This is used by both Dataset.iter_batches() and DatasetPipeline.iter_batches() - -# The algorithm is as follows: - -# In a single async thread, do the following: -# 1. Construct logical batches. This creates groupings of the block object references -# based on the corresponding metadata.num_rows. The blocks are not resolved or sliced. -# 2. If specified, locally shuffle the logical batches. -# 3. Trigger local prefetching of the logical batches. -# 4. Then, in a threadpool consisting of `prefetch_batches` threads: -# 1. Resolve (i.e. call `ray.get()`) on the underlying block references for each -# logical batch. -# 2. Perform the necessary batch slicing to construct full batches. -# 3. Format the batches to the provided batch format. -# 4. Apply the collate function -# 5. Fetch outputs from the threadpool, maintaining order of the batches. - -# Args: -# block_refs: An iterator over block object references and their corresponding -# metadata. -# stats: DatasetStats object to record timing and other statistics. -# clear_block_after_read: Whether to clear the block from object store -# manually (i.e. without waiting for Python's automatic GC) after it -# is read. Doing so will reclaim memory faster and hence reduce the -# memory footprint. However, the caller has to ensure the safety, i.e. -# the block will never be accessed again. -# batch_size: Record batch size, or None to let the system pick. -# batch_format: The format in which to return each batch. -# Specify "default" to use the current block format (promoting -# Arrow to pandas automatically), "pandas" to -# select ``pandas.DataFrame`` or "pyarrow" to select -# ``pyarrow.Table``. Default is "default". -# drop_last: Whether to drop the last batch if it's incomplete. -# collate_fn: A function to apply to each data batch before returning it. -# shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a -# local in-memory shuffle buffer, and this value will serve as the minimum -# number of rows that must be in the local in-memory shuffle buffer in order -# to yield a batch. -# shuffle_seed: The seed to use for the local random shuffle. -# ensure_copy: Whether batches are always copied from the underlying base -# blocks (not zero-copy views). -# prefetch_batches: The number of batches to fetch ahead of the current batch to -# process. If set to greater than 0, a separate thread will be used to fetch -# the specified amount of formatted batches from blocks. This improves -# performance for non-CPU bound UDFs, allowing batch fetching compute and -# formatting to be overlapped with the UDF. Defaults to 0 (no prefetching -# enabled). - -# Returns: -# An iterator over record batches. -# """ -# context = DatasetContext.get_current() - -# if ( -# prefetch_batches > 0 -# and context.actor_prefetcher_enabled -# and not ray.util.client.ray.is_connected() -# ): -# prefetcher = ActorBlockPrefetcher() -# else: -# prefetcher = WaitBlockPrefetcher() - -# def _async_iter_batches(block_refs): -# # Step 1: Construct logical batches based on the metadata. -# batch_iter = _bundle_block_refs_to_logical_batches( -# block_refs, batch_size=batch_size, drop_last=drop_last -# ) - -# # Step 2: Shuffle the logical batches if applicable. -# if shuffle_buffer_min_size is not None: -# batch_iter = _local_shuffle_logical_batches( -# shuffle_buffer_min_size=shuffle_buffer_min_size, -# shuffle_seed=shuffle_seed, -# ) - -# # Step 3: Prefetch logical batches locally. -# if prefetch_batches > 0: -# batch_iter = _prefetch_batches_locally( -# batch_iter, -# prefetcher=prefetcher, -# num_batches_to_prefetch=prefetch_batches, -# stats=stats, -# ) - -# def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): -# # Step 4.1: Resolve the blocks. -# resolved_batch_iter = _resolve_blocks( -# logical_batch_iter, -# clear_block_after_read=clear_block_after_read, -# stats=stats, -# ) - -# # Step 4.2: Slice the blocks to create the batch. -# batch_iter = _construct_batch_from_logical_batch( -# resolved_batch_iter, stats=stats, ensure_copy=ensure_copy -# ) - -# # Step 4.3: Format the batches. -# formatted_batch_iter = _format_batches( -# batch_iter, batch_format=batch_format, stats=stats -# ) - -# # Step 4.4: Apply the collate function if applicable. -# if collate_fn is not None: -# formatted_batch_iter = _collate( -# formatted_batch_iter, collate_fn=collate_fn, stats=stats -# ) -# yield from formatted_batch_iter - -# # Step 4: Use a threadpool for resolving blocks, slicing, formatting, and -# # collation. -# if prefetch_batches > 0: -# batch_iter = _make_async_gen( -# batch_iter, fn=threadpool_computations, num_workers=prefetch_batches -# ) -# # Step 5: Make sure to preserve order from threadpool results. -# yield from _preserve_order(batch_iter) -# else: -# # If no batch prefetching is specified, then don't use a threadpool. -# batch_iter = threadpool_computations(batch_iter) -# # Drop the index since ordering is already preserved as we are not using a -# # threadpool. -# for idx, batch in batch_iter: -# yield batch - -# # Run everything in a separate thread to not block the main thread when waiting -# # for streaming results. -# async_batch_iter = _make_async_gen( -# block_refs, fn=_async_iter_batches, num_workers=1 -# ) - -# yield from async_batch_iter +def iter_batches( + block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], + *, + stats: Optional[DatasetStats] = None, + clear_block_after_read: bool = False, + batch_size: Optional[int] = None, + batch_format: str = "default", + drop_last: bool = False, + collate_fn: Optional[Callable[[DataBatch], Any]] = None, + shuffle_buffer_min_size: Optional[int] = None, + shuffle_seed: Optional[int] = None, + ensure_copy: bool = False, + prefetch_batches: int = 0, +) -> Iterator[DataBatch]: + """Create formatted batches of data from an iterator of block object references and + corresponding metadata. + + This takes a block iterator and creates batch_size batches, slicing, + unioning, shuffling, prefetching, and formatting blocks as needed. + + This is used by both Dataset.iter_batches() and DatasetPipeline.iter_batches() + + The algorithm is as follows: + + In a single async thread, do the following: + 1. Construct logical batches. This creates groupings of the block object references + based on the corresponding metadata.num_rows. The blocks are not resolved or sliced. + 2. If specified, locally shuffle the logical batches. + 3. Trigger local prefetching of the logical batches. + 4. Then, in a threadpool consisting of `prefetch_batches` threads: + 1. Resolve (i.e. call `ray.get()`) on the underlying block references for each + logical batch. + 2. Perform the necessary batch slicing to construct full batches. + 3. Format the batches to the provided batch format. + 4. Apply the collate function + 5. Trace deallocation and eagerly clear block references if necessary. + 6. Fetch outputs from the threadpool, maintaining order of the batches. + + Args: + block_refs: An iterator over block object references and their corresponding + metadata. + stats: DatasetStats object to record timing and other statistics. + clear_block_after_read: Whether to clear the block from object store + manually (i.e. without waiting for Python's automatic GC) after it + is read. Doing so will reclaim memory faster and hence reduce the + memory footprint. However, the caller has to ensure the safety, i.e. + the block will never be accessed again. + batch_size: Record batch size, or None to let the system pick. + batch_format: The format in which to return each batch. + Specify "default" to use the current block format (promoting + Arrow to pandas automatically), "pandas" to + select ``pandas.DataFrame`` or "pyarrow" to select + ``pyarrow.Table``. Default is "default". + drop_last: Whether to drop the last batch if it's incomplete. + collate_fn: A function to apply to each data batch before returning it. + shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a + local in-memory shuffle buffer, and this value will serve as the minimum + number of rows that must be in the local in-memory shuffle buffer in order + to yield a batch. + shuffle_seed: The seed to use for the local random shuffle. + ensure_copy: Whether batches are always copied from the underlying base + blocks (not zero-copy views). + prefetch_batches: The number of batches to fetch ahead of the current batch to + process. If set to greater than 0, a separate thread will be used to fetch + the specified amount of formatted batches from blocks. This improves + performance for non-CPU bound UDFs, allowing batch fetching compute and + formatting to be overlapped with the UDF. Defaults to 0 (no prefetching + enabled). + + Returns: + An iterator over record batches. + """ + context = DatasetContext.get_current() + + if ( + prefetch_batches > 0 + and context.actor_prefetcher_enabled + and not ray.util.client.ray.is_connected() + ): + prefetcher = ActorBlockPrefetcher() + else: + prefetcher = WaitBlockPrefetcher() + + eager_free = clear_block_after_read and context.eager_free + + def _async_iter_batches(block_refs: Iterator[ObjectRef[Block]]) -> Iterator[DataBatch]: + # Step 1: Construct logical batches based on the metadata. + batch_iter = _bundle_block_refs_to_logical_batches( + block_refs, batch_size=batch_size, drop_last=drop_last + ) + + # Step 2: Shuffle the logical batches if applicable. + if shuffle_buffer_min_size is not None: + batch_iter = _local_shuffle_logical_batches( + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + + # Step 3: Prefetch logical batches locally. + if prefetch_batches > 0: + batch_iter = _prefetch_batches_locally( + batch_iter, + prefetcher=prefetcher, + num_batches_to_prefetch=prefetch_batches, + ) + + # Step 4: Use a threadpool for resolving blocks, slicing, formatting, and + # collation. + batch_iter = _batch_in_threadpool(batch_iter, stats=stats, batch_format=batch_format, collate_fn=collate_fn, ensure_copy=ensure_copy, num_threadpool_workers=prefetch_batches) + + # Step 5: Trace deallocation + batch_iter = _trace_deallocation(batch_iter, eager_free=eager_free) + + # Step 6: Restore original order. + batch_iter: Iterator[Batch] = _restore_from_original_order(batch_iter) + + for batch in batch_iter: + yield batch.data + + # Run everything in a separate thread to not block the main thread when waiting + # for streaming results. + async_batch_iter = _make_async_gen( + block_refs, fn=_async_iter_batches, num_workers=1 + ) + + while True: + with stats.iter_total_blocked_s.timer() if stats else nullcontext(): + try: + next_batch = next(async_batch_iter) + except StopIteration: + break + with stats.iter_user_s.timer() if stats else nullcontext(): + yield next_batch def _batch_in_threadpool( logical_batch_iterator: Iterator[LogicalBatch], stats: DatasetStats, - clear_block_after_read: bool = False, batch_format: str = "default", collate_fn: Optional[Callable[[DataBatch], Any]] = None, ensure_copy: bool = False, - prefetch_batches: int = 0, -): + num_threadpool_workers: int = 0, +) -> Iterator[Batch]: """Executes the batching, formatting, and collation logic in a threadpool. Args: logical_batch_iterator: An iterator over logical batches. stats: DatasetStats object to record timing and other statistics. - clear_block_after_read: Whether to clear the block from object store - manually (i.e. without waiting for Python's automatic GC) after it - is read. Doing so will reclaim memory faster and hence reduce the - memory footprint. However, the caller has to ensure the safety, i.e. - the block will never be accessed again. batch_format: The format in which to return each batch. Specify "default" to use the current block format (promoting Arrow to pandas automatically), "pandas" to @@ -206,16 +182,16 @@ def _batch_in_threadpool( collate_fn: A function to apply to each data batch before returning it. ensure_copy: Whether batches are always copied from the underlying base blocks (not zero-copy views). - threadpool_size: The number of threads to use in the threadpool. + num_threadpool_workers: The number of threads to use in the threadpool. """ - def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): + def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]) -> Iterator[Batch]: # Step 4.1: Resolve the blocks. resolved_batch_iter = _resolve_logical_batch(logical_batch_iter, stats=stats) # Step 4.2: Slice the blocks to create the batch. batch_iter = _construct_batch_from_logical_batch( - resolved_batch_iter, stats=stats, ensure_copy=ensure_copy + resolved_batch_iter, ensure_copy=ensure_copy, stats=stats ) # Step 4.3: Format the batches. @@ -230,9 +206,7 @@ def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]): ) yield from formatted_batch_iter - return - - + return _make_async_gen(base_iterator=logical_batch_iterator, fn=threadpool_computations, num_workers=num_threadpool_workers) def _bundle_block_refs_to_logical_batches( block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], @@ -405,7 +379,7 @@ def _resolve_logical_batch(logical_batch_iter: Iterator[LogicalBatch], stats: Op for logical_batch in logical_batch_iter: current_hit, current_miss, current_unknown = _calculate_ref_hits(logical_batch.block_refs) - hits += current_hit, + hits += current_hit misses += current_miss unknowns += current_unknown @@ -421,6 +395,7 @@ def _resolve_logical_batch(logical_batch_iter: Iterator[LogicalBatch], stats: Op def _construct_batch_from_logical_batch( resolved_logical_batch_iter: Iterator[LogicalBatch], ensure_copy: bool = False, + stats: Optional[DatasetStats] = None ) -> Iterator[Tuple[int, Block]]: """Given an iterator over logical batches, returns an iterator over actual constructed batches. @@ -430,41 +405,43 @@ def _construct_batch_from_logical_batch( stats: Dataset stats object used to store block batching time. ensure_copy: Whether batches are always copied from the underlying base blocks (not zero-copy views). + stats: An optional stats object to record formatting times. Returns: An iterator over batch index and batches of the given size. """ for logical_batch in resolved_logical_batch_iter: - output = DelegatingBlockBuilder() - slice_indices = [[0, None] for _ in range(len(logical_batch.blocks))] - if logical_batch.starting_block_idx > 0: - slice_indices[0][0] = logical_batch.starting_block_idx - if logical_batch.ending_block_idx is not None: - slice_indices[-1][1] = logical_batch.ending_block_idx - - for i, block in enumerate(logical_batch.blocks): - accessor = BlockAccessor.for_block(block) - slice_index = slice_indices[i] - output.add_block( - accessor.slice( - slice_index[0], - slice_index[1] - if slice_index[1] is not None - else accessor.num_rows(), - copy=False, + with stats.iter_create_batch_s.timer() if stats else nullcontext(): + output = DelegatingBlockBuilder() + slice_indices = [[0, None] for _ in range(len(logical_batch.blocks))] + if logical_batch.starting_block_idx > 0: + slice_indices[0][0] = logical_batch.starting_block_idx + if logical_batch.ending_block_idx is not None: + slice_indices[-1][1] = logical_batch.ending_block_idx + + for i, block in enumerate(logical_batch.blocks): + accessor = BlockAccessor.for_block(block) + slice_index = slice_indices[i] + output.add_block( + accessor.slice( + slice_index[0], + slice_index[1] + if slice_index[1] is not None + else accessor.num_rows(), + copy=False, + ) ) - ) - batch = output.build() - assert len(batch) == logical_batch.num_rows, ( - len(batch), - logical_batch.num_rows, - ) - if ensure_copy: - # Need to ensure that the batch is a fresh copy. - batch = BlockAccessor.for_block(batch) - batch = batch.slice(0, batch.num_rows(), copy=True) + batch = output.build() + assert len(batch) == logical_batch.num_rows, ( + len(batch), + logical_batch.num_rows, + ) + if ensure_copy: + # Need to ensure that the batch is a fresh copy. + batch = BlockAccessor.for_block(batch) + batch = batch.slice(0, batch.num_rows(), copy=True) yield Batch(logical_batch.batch_idx, batch, logical_batch) @@ -472,6 +449,7 @@ def _construct_batch_from_logical_batch( def _format_batches( block_iter: Iterator[Batch], batch_format: str, + stats: Optional[DatasetStats] = None, ) -> Iterator[Batch]: """Given an iterator of blocks, returns an iterator of formatted batches. @@ -484,9 +462,10 @@ def _format_batches( An iterator over batch index and the formatted batch. """ for batch in block_iter: - formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( - batch_format - ) + with stats.iter_format_batch_s.timer() if stats else nullcontext(): + formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) batch.data = formatted_batch yield batch @@ -494,16 +473,19 @@ def _format_batches( def _collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], + stats: Optional[DatasetStats] = None ) -> Iterator[Tuple[int, Any]]: """Returns an iterator with the provided collate_fn applied to items of the batch iterator. Args: batch_iter: An iterator over formatted batches. + collate_fn: The collate_fn to execute. stats: An optional stats object to record collation time. """ for batch in batch_iter: - batch.data = collate_fn(batch.data) + with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + batch.data = collate_fn(batch.data) yield batch diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 180ad3709d282..5ac858dacaa2d 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -31,12 +31,17 @@ def _make_async_gen( Args: base_iterator: The iterator to asynchronously fetch from. fn: The function to run on the input iterator. - num_workers: The number of threads to use in the threadpool. + num_workers: The number of threads to use in the threadpool. Defaults to 1. Returns: An iterator with the same elements as the base_iterator. """ + # If no threadpool workers are specified, then don't use a threadpool. + if num_workers <= 0: + yield from fn(base_iterator) + return + def convert_to_threadsafe_iterator(base_iterator: Iterator[T]) -> Iterator[T]: class ThreadSafeIterator: def __init__(self, it): @@ -91,7 +96,6 @@ def execute_computation(thread_index: int): yield next_item output_queue.task_done() if num_threads_finished >= num_workers: - output_queue.join() break def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: diff --git a/python/ray/data/_internal/dataset_iterator_impl.py b/python/ray/data/_internal/dataset_iterator_impl.py deleted file mode 100644 index 082f45be9f123..0000000000000 --- a/python/ray/data/_internal/dataset_iterator_impl.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import TYPE_CHECKING, Optional, Union, Iterator, Callable, Any -import time -import warnings - -from ray.data.block import DataBatch -from ray.data.context import DatasetContext -from ray.data.dataset_iterator import DatasetIterator -from ray.data._internal.block_batching import batch_block_refs - -if TYPE_CHECKING: - import pyarrow - from ray.data import Dataset - - -class DatasetIteratorImpl(DatasetIterator): - def __init__( - self, - base_dataset: "Dataset", - ): - self._base_dataset = base_dataset - self._base_context = DatasetContext.get_current() - - def __repr__(self) -> str: - return f"DatasetIterator({self._base_dataset})" - - def iter_batches( - self, - *, - prefetch_blocks: int = 0, - batch_size: Optional[int] = 256, - batch_format: Optional[str] = "default", - drop_last: bool = False, - local_shuffle_buffer_size: Optional[int] = None, - local_shuffle_seed: Optional[int] = None, - _collate_fn: Optional[Callable[[DataBatch], Any]] = None, - ) -> Iterator[DataBatch]: - - DatasetContext._set_current(self._base_context) - - ds = self._base_dataset - block_iterator, stats, executor = ds._plan.execute_to_iterator() - ds._current_executor = executor - time_start = time.perf_counter() - - yield from batch_block_refs( - block_iterator, - stats=stats, - prefetch_blocks=prefetch_blocks, - batch_size=batch_size, - batch_format=batch_format, - drop_last=drop_last, - collate_fn=_collate_fn, - shuffle_buffer_min_size=local_shuffle_buffer_size, - shuffle_seed=local_shuffle_seed, - ) - - stats.iter_total_s.add(time.perf_counter() - time_start) - - def stats(self) -> str: - return self._base_dataset.stats() - - def schema(self) -> Union[type, "pyarrow.lib.Schema"]: - return self._base_dataset.schema() - - def __getattr__(self, name): - if name == "_base_dataset": - raise AttributeError() - - if hasattr(self._base_dataset, name) and not name.startswith("_"): - # Warning for backwards compatibility. TODO: remove this method in 2.5. - warnings.warn( - "session.get_dataset_shard returns a ray.data.DatasetIterator " - "instead of a Dataset/DatasetPipeline as of Ray v2.3. " - "Use iter_torch_batches(), to_tf(), or iter_batches() to " - "iterate over one epoch. See " - "https://docs.ray.io/en/latest/data/api/dataset_iterator.html " - "for full DatasetIterator docs.", - stacklevel=4, - ) - - return getattr(self._base_dataset, name) - - raise AttributeError() diff --git a/python/ray/data/_internal/pipelined_dataset_iterator.py b/python/ray/data/_internal/pipelined_dataset_iterator.py deleted file mode 100644 index cdaf3f73e4d8a..0000000000000 --- a/python/ray/data/_internal/pipelined_dataset_iterator.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import TYPE_CHECKING, Any, Callable, Optional, Union, Iterator -import warnings - -from ray.data.block import DataBatch -from ray.data.dataset_iterator import DatasetIterator - -if TYPE_CHECKING: - import pyarrow - from ray.data import DatasetPipeline - - -class PipelinedDatasetIterator(DatasetIterator): - def __init__( - self, - base_dataset_pipeline: "DatasetPipeline", - ): - self._base_dataset_pipeline = base_dataset_pipeline - self._epoch_iterator = None - - def __repr__(self) -> str: - return f"DatasetIterator({self._base_dataset_pipeline})" - - def _get_next_dataset(self) -> "DatasetPipeline": - if self._epoch_iterator is None: - self._epoch_iterator = self._base_dataset_pipeline.iter_epochs() - - ds = next(self._epoch_iterator) - return ds - - def iter_batches( - self, - *, - prefetch_blocks: int = 0, - batch_size: Optional[int] = 256, - batch_format: Optional[str] = "default", - drop_last: bool = False, - local_shuffle_buffer_size: Optional[int] = None, - local_shuffle_seed: Optional[int] = None, - _collate_fn: Optional[Callable[[DataBatch], Any]] = None, - ) -> Iterator[DataBatch]: - ds = self._get_next_dataset() - return ds.iter_batches( - prefetch_blocks=prefetch_blocks, - batch_size=batch_size, - batch_format=batch_format, - drop_last=drop_last, - local_shuffle_buffer_size=local_shuffle_buffer_size, - local_shuffle_seed=local_shuffle_seed, - _collate_fn=_collate_fn, - ) - - def stats(self) -> str: - return self._base_dataset_pipeline.stats() - - def schema(self) -> Union[type, "pyarrow.lib.Schema"]: - return self._base_dataset_pipeline.schema() - - def __getattr__(self, name): - if name == "_base_dataset_pipeline": - raise AttributeError - - if hasattr(self._base_dataset_pipeline, name) and not name.startswith("_"): - # Warning for backwards compatibility. TODO: remove this method in 2.5. - warnings.warn( - "session.get_dataset_shard returns a ray.data.DatasetIterator " - "instead of a Dataset/DatasetPipeline as of Ray v2.3. " - "Use iter_torch_batches(), to_tf(), or iter_batches() to " - "iterate over one epoch. See " - "https://docs.ray.io/en/latest/data/api/dataset_iterator.html " - "for full DatasetIterator docs." - ) - - return getattr(self._base_dataset_pipeline, name) - else: - return super().__getattr__(name) diff --git a/python/ray/data/_internal/planner/map_batches.py b/python/ray/data/_internal/planner/map_batches.py index 711f275b599a8..9db2461ff74a7 100644 --- a/python/ray/data/_internal/planner/map_batches.py +++ b/python/ray/data/_internal/planner/map_batches.py @@ -11,7 +11,6 @@ def generate_map_batches_fn( batch_size: Optional[int] = DEFAULT_BATCH_SIZE, batch_format: Optional[str] = "default", - prefetch_batches: int = 0, zero_copy_batch: bool = False, ) -> Callable[[Iterator[Block], TaskContext, BatchUDF], Iterator[Block]]: """Generate function to apply the batch UDF to blocks.""" @@ -93,7 +92,6 @@ def process_next_batch(batch: DataBatch) -> Iterator[Block]: batch_size=batch_size, batch_format=batch_format, ensure_copy=not zero_copy_batch and batch_size is not None, - prefetch_batches=prefetch_batches, ) for batch in formatted_batch_iter: diff --git a/python/ray/data/_internal/stream_split_dataset_iterator.py b/python/ray/data/_internal/stream_split_dataset_iterator.py deleted file mode 100644 index 14ee2d882d9e0..0000000000000 --- a/python/ray/data/_internal/stream_split_dataset_iterator.py +++ /dev/null @@ -1,262 +0,0 @@ -import copy -import logging -import time -import threading -from typing import ( - List, - Dict, - Optional, - Iterator, - Callable, - Any, - Union, - TYPE_CHECKING, -) - -import ray - -from ray.data.dataset_iterator import DatasetIterator -from ray.data.block import Block, DataBatch -from ray.data.context import DatasetContext -from ray.data._internal.execution.streaming_executor import StreamingExecutor -from ray.data._internal.execution.legacy_compat import ( - execute_to_legacy_bundle_iterator, -) -from ray.data._internal.block_batching import batch_block_refs -from ray.data._internal.execution.operators.output_splitter import OutputSplitter -from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle -from ray.types import ObjectRef -from ray.util.debug import log_once -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy - -if TYPE_CHECKING: - import pyarrow - from ray.data import Dataset - -logger = logging.getLogger(__name__) - - -BLOCKED_CLIENT_WARN_TIMEOUT = 30 - - -class StreamSplitDatasetIterator(DatasetIterator): - """Implements a collection of iterators over a shared data stream.""" - - @staticmethod - def create( - base_dataset: "Dataset", - n: int, - equal: bool, - locality_hints: Optional[List[NodeIdStr]], - ) -> List["StreamSplitDatasetIterator"]: - """Create a split iterator from the given base Dataset and options. - - See also: `Dataset.streaming_split`. - """ - ctx = DatasetContext.get_current() - - # To avoid deadlock, the concurrency on this actor must be set to at least `n`. - coord_actor = SplitCoordinator.options( - max_concurrency=n, - scheduling_strategy=NodeAffinitySchedulingStrategy( - ray.get_runtime_context().get_node_id(), soft=False - ), - ).remote(ctx, base_dataset, n, equal, locality_hints) - - return [ - StreamSplitDatasetIterator(base_dataset, coord_actor, i) for i in range(n) - ] - - def __init__( - self, - base_dataset: "Dataset", - coord_actor: ray.actor.ActorHandle, - output_split_idx: int, - ): - self._base_dataset = base_dataset - self._coord_actor = coord_actor - self._output_split_idx = output_split_idx - - def iter_batches( - self, - *, - prefetch_blocks: int = 0, - batch_size: int = 256, - batch_format: Optional[str] = "default", - drop_last: bool = False, - local_shuffle_buffer_size: Optional[int] = None, - local_shuffle_seed: Optional[int] = None, - _collate_fn: Optional[Callable[[DataBatch], Any]] = None, - ) -> Iterator[DataBatch]: - """Implements DatasetIterator.""" - - def gen_blocks() -> Iterator[ObjectRef[Block]]: - cur_epoch = ray.get( - self._coord_actor.start_epoch.remote(self._output_split_idx) - ) - future: ObjectRef[ - Optional[ObjectRef[Block]] - ] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx) - while True: - block_ref: Optional[ObjectRef[Block]] = ray.get(future) - if not block_ref: - break - else: - future = self._coord_actor.get.remote( - cur_epoch, self._output_split_idx - ) - yield block_ref - - yield from batch_block_refs( - gen_blocks(), - stats=None, - prefetch_blocks=prefetch_blocks, - batch_size=batch_size, - batch_format=batch_format, - drop_last=drop_last, - collate_fn=_collate_fn, - shuffle_buffer_min_size=local_shuffle_buffer_size, - shuffle_seed=local_shuffle_seed, - ) - - def stats(self) -> str: - """Implements DatasetIterator.""" - return self._base_dataset.stats() - - def schema(self) -> Union[type, "pyarrow.lib.Schema"]: - """Implements DatasetIterator.""" - return self._base_dataset.schema() - - -@ray.remote(num_cpus=0) -class SplitCoordinator: - """Coordinator actor for routing blocks to output splits. - - This actor runs a streaming executor locally on its main thread. Clients can - retrieve results via actor calls running on other threads. - """ - - def __init__( - self, - ctx: DatasetContext, - dataset: "Dataset", - n: int, - equal: bool, - locality_hints: Optional[List[NodeIdStr]], - ): - # Automatically set locality with output to the specified location hints. - if locality_hints: - ctx.execution_options.locality_with_output = locality_hints - logger.info(f"Auto configuring locality_with_output={locality_hints}") - - DatasetContext._set_current(ctx) - self._base_dataset = dataset - self._n = n - self._equal = equal - self._locality_hints = locality_hints - self._lock = threading.RLock() - - # Guarded by self._lock. - self._next_bundle: Dict[int, RefBundle] = {} - self._unfinished_clients_in_epoch = n - self._cur_epoch = -1 - - def gen_epochs(): - while True: - executor = StreamingExecutor(copy.deepcopy(ctx.execution_options)) - - def add_split_op(dag): - return OutputSplitter(dag, n, equal, locality_hints) - - output_iterator = execute_to_legacy_bundle_iterator( - executor, - dataset._plan, - True, - dataset._plan._dataset_uuid, - dag_rewrite=add_split_op, - ) - yield output_iterator - - self._next_epoch = gen_epochs() - self._output_iterator = None - - def start_epoch(self, split_idx: int) -> str: - """Called to start an epoch. - - Returns: - UUID for the epoch, which must be used when accessing results via get(). - """ - - # Wait for all clients to arrive at the barrier before starting a new epoch. - epoch_id = self._barrier(split_idx) - return epoch_id - - def get(self, epoch_id: int, output_split_idx: int) -> Optional[ObjectRef[Block]]: - """Blocking get operation. - - This is intended to be called concurrently from multiple clients. - """ - - if epoch_id != self._cur_epoch: - raise ValueError( - "Invalid iterator: the datastream has moved on to another epoch." - ) - - try: - # Ensure there is at least one bundle. - with self._lock: - if output_split_idx in self._next_bundle: - next_bundle = self._next_bundle[output_split_idx] - else: - next_bundle = None - - # Fetch next bundle if needed. - if next_bundle is None: - # This is a BLOCKING call, so do it outside the lock. - next_bundle = self._output_iterator.get_next(output_split_idx) - - block = next_bundle.blocks.pop()[0] - - # Accumulate any remaining blocks in next_bundle map as needed. - with self._lock: - self._next_bundle[output_split_idx] = next_bundle - if not next_bundle.blocks: - del self._next_bundle[output_split_idx] - - return block - except StopIteration: - return None - - def _barrier(self, split_idx: int) -> int: - """Arrive and block until the start of the given epoch.""" - - # Decrement and await all clients to arrive here. - with self._lock: - starting_epoch = self._cur_epoch - self._unfinished_clients_in_epoch -= 1 - - start_time = time.time() - while ( - self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0 - ): - if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT: - if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"): - logger.warning( - f"StreamSplitDatasetIterator(epoch={starting_epoch}, " - f"split={split_idx}) blocked waiting on other clients " - f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All " - "clients must read from the DatasetIterator splits at " - "the same time. This warning will not be printed again " - "for this epoch." - ) - time.sleep(0.1) - - # Advance to the next epoch. - with self._lock: - if self._cur_epoch == starting_epoch: - self._cur_epoch += 1 - self._unfinished_clients_in_epoch = self._n - self._output_iterator = next(self._next_epoch) - - assert self._output_iterator is not None - return starting_epoch + 1 diff --git a/python/ray/data/context.py b/python/ray/data/context.py index 6630cc02bb291..4c4bb417843c0 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -107,6 +107,9 @@ # Set this env var to enable distributed tqdm (experimental). DEFAULT_USE_RAY_TQDM = bool(int(os.environ.get("RAY_TQDM", "1"))) +# Set this to True to use the legacy iter_batches codepath prior to 2.4. +DEFAULT_USE_LEGACY_ITER_BATCHES = False + # Use this to prefix important warning messages for the user. WARN_PREFIX = "⚠️ " @@ -152,6 +155,7 @@ def __init__( optimizer_enabled: bool, execution_options: "ExecutionOptions", use_ray_tqdm: bool, + use_legacy_iter_batches: bool, ): """Private constructor (use get_current() instead).""" self.block_splitting_enabled = block_splitting_enabled @@ -182,6 +186,7 @@ def __init__( # TODO: expose execution options in Dataset public APIs. self.execution_options = execution_options self.use_ray_tqdm = use_ray_tqdm + self.use_legacy_iter_batches = use_legacy_iter_batches @staticmethod def get_current() -> "DatasetContext": @@ -228,6 +233,7 @@ def get_current() -> "DatasetContext": optimizer_enabled=DEFAULT_OPTIMIZER_ENABLED, execution_options=ExecutionOptions(), use_ray_tqdm=DEFAULT_USE_RAY_TQDM, + use_legacy_iter_batches=DEFAULT_USE_LEGACY_ITER_BATCHES, ) return _default_context diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4830ebdef4f42..b6ea124df92c6 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -51,8 +51,8 @@ from ray.data._internal.planner.write import generate_write_fn from ray.data.dataset_iterator import DatasetIterator from ray.data._internal.block_list import BlockList -from ray.data._internal.dataset_iterator_impl import DatasetIteratorImpl -from ray.data._internal.stream_split_dataset_iterator import StreamSplitDatasetIterator +from ray.data._internal.dataset_iterator.dataset_iterator_impl import DatasetIteratorImpl +from ray.data._internal.dataset_iterator.stream_split_dataset_iterator import StreamSplitDatasetIterator from ray.data._internal.compute import ( ActorPoolStrategy, CallableClass, @@ -380,7 +380,6 @@ def map_batches( batch_size: Optional[Union[int, Literal["default"]]] = "default", compute: Optional[Union[str, ComputeStrategy]] = None, batch_format: Optional[str] = "default", - prefetch_batches: int = 0, zero_copy_batch: bool = False, fn_args: Optional[Iterable[Any]] = None, fn_kwargs: Optional[Dict[str, Any]] = None, @@ -543,14 +542,6 @@ def map_batches( ``Dict[str, numpy.ndarray]`` for tabular datasets, or None to return the underlying block exactly as is with no additional formatting. The default is "default". - prefetch_batches: The number of batches to fetch ahead of the current batch - to process. If set to greater than 0, a separate thread will be used - to fetch the specified amount of formatted batches from blocks. This - improves performance for non-CPU bound UDFs, allowing batch fetching - compute and formatting to be overlapped with the UDF. Defaults to 0 (no - prefetching enabled.) Increasing the number of batches to prefetch can - result in higher throughput, at the expense of requiring more heap - memory to buffer the batches. zero_copy_batch: Whether ``fn`` should be provided zero-copy, read-only batches. If this is ``True`` and no copy is required for the ``batch_format`` conversion, the batch will be a zero-copy, read-only @@ -650,7 +641,6 @@ def map_batches( transform_fn = generate_map_batches_fn( batch_size=batch_size, batch_format=batch_format, - prefetch_batches=prefetch_batches, zero_copy_batch=zero_copy_batch, ) diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index dea973a43eac3..c07e6c177b8ab 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -28,7 +28,7 @@ PipelineExecutor, PipelineSplitExecutorCoordinator, ) -from ray.data._internal.pipelined_dataset_iterator import PipelinedDatasetIterator +from ray.data._internal.dataset_iterator.pipelined_dataset_iterator import PipelinedDatasetIterator from ray.data._internal.plan import ExecutionPlan from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.data.block import BatchUDF, Block, DataBatch, KeyFn, RowUDF diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index fbe9f52ad57d5..8e02f96f52699 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,5 +1,6 @@ from copy import copy import pytest +import time from typing import Iterator, List, Tuple from unittest.mock import patch @@ -16,6 +17,7 @@ BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( + iter_batches, _bundle_block_refs_to_logical_batches, _local_shuffle_logical_batches, _prefetch_batches_locally, @@ -297,6 +299,78 @@ def test_restore_from_original_order(): idx = [batch.batch_idx for batch in ordered] assert idx == [0, 1, 2, 3] +# Test for 3 cases +# 1. Batch size is less than block size +# 2. Batch size is more than block size +# 3. Block size is not divisble by batch size +@pytest.mark.parametrize("batch_size", [1, 4, 3]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_iter_batches_e2e( + ray_start_regular_shared, batch_size, drop_last +): + + def collate_fn(batch: pd.DataFrame): + return batch + 1 + + block_refs_iter = block_generator(num_blocks=4, num_rows=2) + + output_batches = iter_batches( + block_refs_iter, + batch_size=batch_size, + batch_format="pandas", + collate_fn=collate_fn, + drop_last=drop_last, + ) + + output_batches = list(output_batches) + + assert len(output_batches) > 0 + for df in output_batches: + # Check batch formatting. + assert isinstance(df, pd.DataFrame) + # Check batch size. + if batch_size == 3 and not drop_last: + assert len(df) in {2, 3} + else: + assert len(df) == batch_size + + concat_df = pd.concat(output_batches) + # Test that collate_fn is applied. + assert concat_df["foo"].iloc[0] == 1 + # Make sure order is preserved. + for i in range(len(concat_df) - 1): + assert concat_df["foo"].iloc[i + 1] >= concat_df["foo"].iloc[i] + + +def test_iter_batches_e2e_async(ray_start_regular_shared): + """We add time.sleep in 3 places: + 1. In the base generator to simulate streaming executor blocking on next results. + 2. In the collate_fn to simulate expensive slicing/formatting/collation + 3. In the user thread to simulate training. + """ + def collate_fn(batch): + time.sleep(2) + return batch + + block_refs_iter = block_generator(num_blocks=20, num_rows=2) + start_time = time.time() + output_batches = iter_batches( + block_refs_iter, batch_size=None, collate_fn=collate_fn, prefetch_batches=4 + ) + batches = [] + for batch in output_batches: + time.sleep(1.5) + batches.append(batch) + end_time = time.time() + + # 20 batches, 1.5 second sleep. Should be less than 45 seconds, even with some + # overhead. + # If there was no overlap, then we would expect this to take at least 20*2.5 = 50 + assert end_time - start_time < 45, end_time - start_time + + assert len(batches) == 20 + assert all(len(batch) == 2 for batch in batches) + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_batcher.py b/python/ray/data/tests/test_batcher.py index 20769fb4f85ea..2262a2c30bc91 100644 --- a/python/ray/data/tests/test_batcher.py +++ b/python/ray/data/tests/test_batcher.py @@ -2,7 +2,7 @@ import pyarrow as pa -from ray.data._internal.batcher import ShufflingBatcher +from ray.data._internal.batcher import ShufflingBatcher, _blocks_to_batches def gen_block(num_rows): @@ -127,6 +127,39 @@ def next_and_check( ) +@pytest.mark.parametrize("block_size", [1, 10]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_blocks_to_batches(block_size, drop_last): + def block_generator(num_rows, num_blocks): + for _ in range(num_blocks): + yield gen_block(num_rows) + + num_blocks = 5 + block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) + + batch_size = 3 + batch_iter = _blocks_to_batches( + block_iter, batch_size=batch_size, drop_last=drop_last + ) + + if drop_last: + for batch in batch_iter: + assert len(batch) == batch_size + else: + full_batches = 0 + leftover_batches = 0 + + dataset_size = block_size * num_blocks + for batch in batch_iter: + if len(batch) == batch_size: + full_batches += 1 + if len(batch) == (dataset_size % batch_size): + leftover_batches += 1 + + assert leftover_batches == 1 + assert full_batches == (dataset_size // batch_size) + + if __name__ == "__main__": import sys From dcfdb065f6541b1ed4f0552e54c794d8ca82e540 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 18:05:16 -0700 Subject: [PATCH 16/75] comments Signed-off-by: amogkam --- .../block_batching/block_batching.py | 6 +-- .../_internal/block_batching/interfaces.py | 6 +-- .../_internal/block_batching/iter_batches.py | 26 +++++----- .../tests/block_batching/test_iter_batches.py | 48 +++++++++---------- 4 files changed, 44 insertions(+), 42 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 2c4b380eb6a1d..28093f225389b 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -36,7 +36,7 @@ def batch_block_refs( prefetch_blocks: int = 0, clear_block_after_read: bool = False, batch_size: Optional[int] = None, - batch_format: str = "default", + batch_format: Optional[str] = "default", drop_last: bool = False, collate_fn: Optional[Callable[[DataBatch], Any]] = None, shuffle_buffer_min_size: Optional[int] = None, @@ -128,7 +128,7 @@ def batch_blocks( *, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, batch_size: Optional[int] = None, - batch_format: str = "default", + batch_format: Optional[str] = "default", drop_last: bool = False, collate_fn: Optional[Callable[[DataBatch], DataBatch]] = None, shuffle_buffer_min_size: Optional[int] = None, @@ -343,7 +343,7 @@ def get_iter_next_batch_s_timer(): def _format_batches( block_iter: Iterator[Block], - batch_format: str, + batch_format: Optional[str], stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[DataBatch]: """Given an iterator of blocks, returns an iterator of formatted batches. diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 5165b272b4b40..65332656ddc02 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,9 +1,9 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Any, List, Optional import ray from ray.types import ObjectRef -from ray.data.block import Block, DataBatch +from ray.data.block import Block @dataclass @@ -67,7 +67,7 @@ class Batch: """ batch_idx: int - data: DataBatch + data: Any logical_batch: LogicalBatch diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 085d2108083ea..6c7a14c7b2a9e 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -13,7 +13,7 @@ from ray.data._internal.memory_tracing import trace_deallocation -def _bundle_block_refs_to_logical_batches( +def bundle_block_refs_to_logical_batches( block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], batch_size: Optional[int], drop_last: bool = False, @@ -84,7 +84,7 @@ def _bundle_block_refs_to_logical_batches( global_index += 1 -def _local_shuffle_logical_batches( +def local_shuffle_logical_batches( logical_batch_iterator: Iterator[LogicalBatch], shuffle_buffer_min_size: int, shuffle_seed: Optional[int] = None, @@ -119,7 +119,7 @@ def _local_shuffle_logical_batches( global_counter += 1 -def _prefetch_batches_locally( +def prefetch_batches_locally( logical_batch_iter: Iterator[LogicalBatch], prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, @@ -176,17 +176,19 @@ def get_next_batches() -> Iterator[List[LogicalBatch]]: yield batch -def _resolve_logical_batch(logical_batch_iter: Iterator[LogicalBatch]): +def resolve_logical_batch( + logical_batch_iter: Iterator[LogicalBatch], +) -> Iterator[LogicalBatch]: """Resolves the block references for each logical batch.""" for logical_batch in logical_batch_iter: logical_batch.resolve() yield logical_batch -def _construct_batch_from_logical_batch( +def construct_batch_from_logical_batch( resolved_logical_batch_iter: Iterator[LogicalBatch], ensure_copy: bool = False, -) -> Iterator[Tuple[int, Block]]: +) -> Iterator[Batch]: """Given an iterator over logical batches, returns an iterator over actual constructed batches. @@ -234,9 +236,9 @@ def _construct_batch_from_logical_batch( yield Batch(logical_batch.batch_idx, batch, logical_batch) -def _format_batches( +def format_batches( block_iter: Iterator[Batch], - batch_format: str, + batch_format: Optional[str], ) -> Iterator[Batch]: """Given an iterator of blocks, returns an iterator of formatted batches. @@ -256,10 +258,10 @@ def _format_batches( yield batch -def _collate( +def collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], -) -> Iterator[Tuple[int, Any]]: +) -> Iterator[Batch]: """Returns an iterator with the provided collate_fn applied to items of the batch iterator. @@ -272,7 +274,7 @@ def _collate( yield batch -def _trace_deallocation( +def trace_deallocation( batch_iter: Iterator[Batch], eager_free: bool ) -> Iterator[Batch]: """Trace deallocation of the underlying block references for each batch. @@ -288,7 +290,7 @@ def _trace_deallocation( yield batch -def _restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: +def restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` This function will yield items from `base_iterator` in the correct order based on diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index fbe9f52ad57d5..fcaa861c0aeba 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -16,15 +16,15 @@ BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( - _bundle_block_refs_to_logical_batches, - _local_shuffle_logical_batches, - _prefetch_batches_locally, - _resolve_logical_batch, - _construct_batch_from_logical_batch, - _format_batches, - _collate, - _trace_deallocation, - _restore_from_original_order, + bundle_block_refs_to_logical_batches, + local_shuffle_logical_batches, + prefetch_batches_locally, + resolve_logical_batch, + construct_batch_from_logical_batch, + format_batches, + collate, + trace_deallocation, + restore_from_original_order, ) @@ -44,7 +44,7 @@ def block_generator( def logical_batch_generator( num_rows: int, num_blocks: int, batch_size: int = None ) -> Iterator[LogicalBatch]: - logical_batch_iter = _bundle_block_refs_to_logical_batches( + logical_batch_iter = bundle_block_refs_to_logical_batches( block_generator(num_rows=num_rows, num_blocks=num_blocks), batch_size=batch_size ) return logical_batch_iter @@ -66,7 +66,7 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): batch_size = None block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) block_refs = list(block_iter) - logical_batch_iter = _bundle_block_refs_to_logical_batches( + logical_batch_iter = bundle_block_refs_to_logical_batches( iter(block_refs), batch_size=batch_size ) logical_batches = list(logical_batch_iter) @@ -83,7 +83,7 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): batch_size = 1 block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) block_refs = list(block_iter) - logical_batch_iter = _bundle_block_refs_to_logical_batches( + logical_batch_iter = bundle_block_refs_to_logical_batches( iter(block_refs), batch_size=batch_size ) logical_batches = list(logical_batch_iter) @@ -100,7 +100,7 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): batch_size = 2 block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) block_refs = list(block_iter) - logical_batch_iter = _bundle_block_refs_to_logical_batches( + logical_batch_iter = bundle_block_refs_to_logical_batches( iter(block_refs), batch_size=batch_size ) logical_batches = list(logical_batch_iter) @@ -115,7 +115,7 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): batch_size = 3 block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) block_refs = list(block_iter) - logical_batch_iter = _bundle_block_refs_to_logical_batches( + logical_batch_iter = bundle_block_refs_to_logical_batches( iter(block_refs), batch_size=batch_size ) logical_batches = list(logical_batch_iter) @@ -132,7 +132,7 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): batch_size = 3 block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) block_refs = list(block_iter) - logical_batch_iter = _bundle_block_refs_to_logical_batches( + logical_batch_iter = bundle_block_refs_to_logical_batches( iter(block_refs), batch_size=batch_size, drop_last=True ) logical_batches = list(logical_batch_iter) @@ -152,7 +152,7 @@ def test_local_shuffle_logical_batches(ray_start_regular_shared): shuffle_buffer_min_size = 1 logical_batches = list(logical_batch_generator(num_rows_per_block, num_blocks)) shuffled_batches = list( - _local_shuffle_logical_batches( + local_shuffle_logical_batches( iter(logical_batches), shuffle_buffer_min_size=shuffle_buffer_min_size, shuffle_seed=shuffle_seed, @@ -167,7 +167,7 @@ def test_local_shuffle_logical_batches(ray_start_regular_shared): shuffle_buffer_min_size = 2 logical_batches = list(logical_batch_generator(num_rows_per_block, num_blocks)) shuffled_batches = list( - _local_shuffle_logical_batches( + local_shuffle_logical_batches( iter(logical_batches), shuffle_buffer_min_size=shuffle_buffer_min_size, shuffle_seed=shuffle_seed, @@ -194,7 +194,7 @@ def prefetch_blocks(self, blocks: List[Block]): num_blocks = 10 prefetcher = DummyPrefetcher() logical_batches = list(logical_batch_generator(1, num_blocks)) - prefetch_batch_iter = _prefetch_batches_locally( + prefetch_batch_iter = prefetch_batches_locally( iter(logical_batches), prefetcher=prefetcher, num_batches_to_prefetch=num_batches_to_prefetch, @@ -222,7 +222,7 @@ def prefetch_blocks(self, blocks: List[Block]): def test_resolve_logical_batches(ray_start_regular_shared): logical_batches = list(logical_batch_generator(1, 1)) - resolved_iter = _resolve_logical_batch(iter(logical_batches)) + resolved_iter = resolve_logical_batch(iter(logical_batches)) assert next(resolved_iter).blocks == ray.get(logical_batches[0].block_refs) @@ -234,7 +234,7 @@ def test_construct_batch_from_logical_batch(ray_start_regular_shared, block_size resolved_logical_batch_generator(block_size, num_blocks, batch_size=batch_size) ) - created_batches = list(_construct_batch_from_logical_batch(iter(logical_batches))) + created_batches = list(construct_batch_from_logical_batch(iter(logical_batches))) for i, batch in enumerate(created_batches): assert i == batch.batch_idx @@ -247,7 +247,7 @@ def test_format_batches(ray_start_regular_shared, batch_format): Batch(i, ray.get(data[0]), None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) ] - batch_iter = _format_batches(batches, batch_format=batch_format) + batch_iter = format_batches(batches, batch_format=batch_format) for i, batch in enumerate(batch_iter): assert batch.batch_idx == i @@ -268,7 +268,7 @@ def collate_fn(batch): Batch(i, ray.get(data[0]), None) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) ] - batch_iter = _collate(batches, collate_fn=collate_fn) + batch_iter = collate(batches, collate_fn=collate_fn) for i, batch in enumerate(batch_iter): assert batch.batch_idx == i @@ -279,7 +279,7 @@ def collate_fn(batch): @pytest.mark.parametrize("eager_free", [True, False]) def test_trace_deallocation(mock, eager_free): batches = [Batch(0, 0, LogicalBatch(0, [0], 0, None, 1))] - batch_iter = _trace_deallocation(iter(batches), eager_free=eager_free) + batch_iter = trace_deallocation(iter(batches), eager_free=eager_free) # Test that the underlying batch is not modified. assert next(batch_iter) == batches[0] mock.assert_called_once_with(0, loc="iter_batches", free=eager_free) @@ -293,7 +293,7 @@ def test_restore_from_original_order(): Batch(2, None, None), ] - ordered = list(_restore_from_original_order(iter(base_iterator))) + ordered = list(restore_from_original_order(iter(base_iterator))) idx = [batch.batch_idx for batch in ordered] assert idx == [0, 1, 2, 3] From 5549fb4963b361e61130fad2910a304c97858567 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 18:09:10 -0700 Subject: [PATCH 17/75] lint Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/block_batching.py | 6 +++--- python/ray/data/_internal/block_batching/iter_batches.py | 2 +- python/ray/data/_internal/execution/interfaces.py | 6 ++++-- python/ray/data/_internal/split.py | 6 +++--- python/ray/data/tests/block_batching/test_iter_batches.py | 4 ++-- python/ray/data/tests/test_util.py | 6 +++--- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 2c4b380eb6a1d..aa0f1d7bca435 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -12,7 +12,7 @@ ) from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.stats import DatasetPipelineStats, DatasetStats -from ray.data._internal.memory_tracing import trace_deallocation +from ray.data._internal.memory_tracing import trace_deallocation_for_batch from ray.data.block import Block, BlockAccessor, DataBatch from ray.data.context import DatasetContext from ray.types import ObjectRef @@ -254,7 +254,7 @@ def _prefetch_blocks( if num_blocks_to_prefetch == 0: for block_ref in block_ref_iter: yield block_ref - trace_deallocation( + trace_deallocation_for_batch( block_ref, "block_batching._prefetch_blocks", free=eager_free ) @@ -275,7 +275,7 @@ def _prefetch_blocks( except StopIteration: pass yield block_ref - trace_deallocation( + trace_deallocation_for_batch( block_ref, "block_batching._prefetch_blocks", free=eager_free ) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 6c7a14c7b2a9e..94b027ca08878 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -274,7 +274,7 @@ def collate( yield batch -def trace_deallocation( +def trace_deallocation_for_batch( batch_iter: Iterator[Batch], eager_free: bool ) -> Iterator[Batch]: """Trace deallocation of the underlying block references for each batch. diff --git a/python/ray/data/_internal/execution/interfaces.py b/python/ray/data/_internal/execution/interfaces.py index 41d02af51bc00..18cb59ae30c00 100644 --- a/python/ray/data/_internal/execution/interfaces.py +++ b/python/ray/data/_internal/execution/interfaces.py @@ -3,7 +3,7 @@ import ray from ray.data._internal.logical.interfaces import Operator -from ray.data._internal.memory_tracing import trace_deallocation +from ray.data._internal.memory_tracing import trace_deallocation_for_batch from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.stats import DatasetStats, StatsDict from ray.data.block import Block, BlockMetadata @@ -78,7 +78,9 @@ def destroy_if_owned(self) -> int: """ should_free = self.owns_blocks and DatasetContext.get_current().eager_free for b in self.blocks: - trace_deallocation(b[0], "RefBundle.destroy_if_owned", free=should_free) + trace_deallocation_for_batch( + b[0], "RefBundle.destroy_if_owned", free=should_free + ) return self.size_bytes() if should_free else 0 def get_cached_location(self) -> Optional[NodeIdStr]: diff --git a/python/ray/data/_internal/split.py b/python/ray/data/_internal/split.py index e66d2c9748583..331d2d5cd8050 100644 --- a/python/ray/data/_internal/split.py +++ b/python/ray/data/_internal/split.py @@ -5,7 +5,7 @@ import ray from ray.data._internal.block_list import BlockList from ray.data._internal.remote_fn import cached_remote_fn -from ray.data._internal.memory_tracing import trace_deallocation +from ray.data._internal.memory_tracing import trace_deallocation_for_batch from ray.data.block import ( Block, BlockPartition, @@ -205,10 +205,10 @@ def _split_all_blocks( # only be consumed by the owner). if owned_by_consumer: for b in blocks_splitted: - trace_deallocation(b, "split._split_all_blocks") + trace_deallocation_for_batch(b, "split._split_all_blocks") else: for b in blocks_splitted: - trace_deallocation(b, "split._split_all_blocks", free=False) + trace_deallocation_for_batch(b, "split._split_all_blocks", free=False) return itertools.chain.from_iterable(all_blocks_split_results) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index fcaa861c0aeba..21b0282430868 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -23,7 +23,7 @@ construct_batch_from_logical_batch, format_batches, collate, - trace_deallocation, + trace_deallocation_for_batch, restore_from_original_order, ) @@ -279,7 +279,7 @@ def collate_fn(batch): @pytest.mark.parametrize("eager_free", [True, False]) def test_trace_deallocation(mock, eager_free): batches = [Batch(0, 0, LogicalBatch(0, [0], 0, None, 1))] - batch_iter = trace_deallocation(iter(batches), eager_free=eager_free) + batch_iter = trace_deallocation_for_batch(iter(batches), eager_free=eager_free) # Test that the underlying batch is not modified. assert next(batch_iter) == batches[0] mock.assert_called_once_with(0, loc="iter_batches", free=eager_free) diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index f9bd0afbbc5d8..18ee47cc43b24 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -5,7 +5,7 @@ from ray.data._internal.util import _check_pyarrow_version, _split_list from ray.data._internal.memory_tracing import ( trace_allocation, - trace_deallocation, + trace_deallocation_for_batch, leak_report, ) from ray.data.tests.conftest import * # noqa: F401, F403 @@ -51,8 +51,8 @@ def test_memory_tracing(enabled): trace_allocation(ref1, "test1") trace_allocation(ref2, "test2") trace_allocation(ref3, "test5") - trace_deallocation(ref1, "test3", free=False) - trace_deallocation(ref2, "test4", free=True) + trace_deallocation_for_batch(ref1, "test3", free=False) + trace_deallocation_for_batch(ref2, "test4", free=True) ray.get(ref1) with pytest.raises(ray.exceptions.ObjectFreedError): ray.get(ref2) From c414c52d457f2f62546be23787d59934b6417d72 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 18:11:43 -0700 Subject: [PATCH 18/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/block_batching.py | 6 +++--- python/ray/data/_internal/execution/interfaces.py | 6 ++---- python/ray/data/_internal/split.py | 6 +++--- python/ray/data/tests/test_util.py | 6 +++--- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index aa0f1d7bca435..2c4b380eb6a1d 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -12,7 +12,7 @@ ) from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.stats import DatasetPipelineStats, DatasetStats -from ray.data._internal.memory_tracing import trace_deallocation_for_batch +from ray.data._internal.memory_tracing import trace_deallocation from ray.data.block import Block, BlockAccessor, DataBatch from ray.data.context import DatasetContext from ray.types import ObjectRef @@ -254,7 +254,7 @@ def _prefetch_blocks( if num_blocks_to_prefetch == 0: for block_ref in block_ref_iter: yield block_ref - trace_deallocation_for_batch( + trace_deallocation( block_ref, "block_batching._prefetch_blocks", free=eager_free ) @@ -275,7 +275,7 @@ def _prefetch_blocks( except StopIteration: pass yield block_ref - trace_deallocation_for_batch( + trace_deallocation( block_ref, "block_batching._prefetch_blocks", free=eager_free ) diff --git a/python/ray/data/_internal/execution/interfaces.py b/python/ray/data/_internal/execution/interfaces.py index 18cb59ae30c00..41d02af51bc00 100644 --- a/python/ray/data/_internal/execution/interfaces.py +++ b/python/ray/data/_internal/execution/interfaces.py @@ -3,7 +3,7 @@ import ray from ray.data._internal.logical.interfaces import Operator -from ray.data._internal.memory_tracing import trace_deallocation_for_batch +from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.stats import DatasetStats, StatsDict from ray.data.block import Block, BlockMetadata @@ -78,9 +78,7 @@ def destroy_if_owned(self) -> int: """ should_free = self.owns_blocks and DatasetContext.get_current().eager_free for b in self.blocks: - trace_deallocation_for_batch( - b[0], "RefBundle.destroy_if_owned", free=should_free - ) + trace_deallocation(b[0], "RefBundle.destroy_if_owned", free=should_free) return self.size_bytes() if should_free else 0 def get_cached_location(self) -> Optional[NodeIdStr]: diff --git a/python/ray/data/_internal/split.py b/python/ray/data/_internal/split.py index 331d2d5cd8050..e66d2c9748583 100644 --- a/python/ray/data/_internal/split.py +++ b/python/ray/data/_internal/split.py @@ -5,7 +5,7 @@ import ray from ray.data._internal.block_list import BlockList from ray.data._internal.remote_fn import cached_remote_fn -from ray.data._internal.memory_tracing import trace_deallocation_for_batch +from ray.data._internal.memory_tracing import trace_deallocation from ray.data.block import ( Block, BlockPartition, @@ -205,10 +205,10 @@ def _split_all_blocks( # only be consumed by the owner). if owned_by_consumer: for b in blocks_splitted: - trace_deallocation_for_batch(b, "split._split_all_blocks") + trace_deallocation(b, "split._split_all_blocks") else: for b in blocks_splitted: - trace_deallocation_for_batch(b, "split._split_all_blocks", free=False) + trace_deallocation(b, "split._split_all_blocks", free=False) return itertools.chain.from_iterable(all_blocks_split_results) diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index 18ee47cc43b24..f9bd0afbbc5d8 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -5,7 +5,7 @@ from ray.data._internal.util import _check_pyarrow_version, _split_list from ray.data._internal.memory_tracing import ( trace_allocation, - trace_deallocation_for_batch, + trace_deallocation, leak_report, ) from ray.data.tests.conftest import * # noqa: F401, F403 @@ -51,8 +51,8 @@ def test_memory_tracing(enabled): trace_allocation(ref1, "test1") trace_allocation(ref2, "test2") trace_allocation(ref3, "test5") - trace_deallocation_for_batch(ref1, "test3", free=False) - trace_deallocation_for_batch(ref2, "test4", free=True) + trace_deallocation(ref1, "test3", free=False) + trace_deallocation(ref2, "test4", free=True) ray.get(ref1) with pytest.raises(ray.exceptions.ObjectFreedError): ray.get(ref2) From 44e63c45518fbf9b8d18a51d99e4beab8703eb5e Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 19:03:23 -0700 Subject: [PATCH 19/75] integration Signed-off-by: amogkam --- python/ray/data/_internal/batcher.py | 61 ------- .../block_batching/block_batching.py | 79 +++++++-- .../_internal/block_batching/iter_batches.py | 64 +++++-- .../ray/data/_internal/block_batching/util.py | 10 +- .../data/_internal/execution/legacy_compat.py | 8 +- python/ray/data/_internal/plan.py | 11 +- python/ray/data/dataset.py | 21 ++- python/ray/data/dataset_iterator.py | 156 +++++++++++++++--- python/ray/data/dataset_pipeline.py | 6 +- .../block_batching/test_block_batching.py | 41 ----- .../tests/block_batching/test_iter_batches.py | 7 +- .../data/tests/block_batching/test_util.py | 1 + python/ray/data/tests/test_batcher.py | 35 +--- 13 files changed, 289 insertions(+), 211 deletions(-) diff --git a/python/ray/data/_internal/batcher.py b/python/ray/data/_internal/batcher.py index a5c84cabfec84..d1358990dfcbd 100644 --- a/python/ray/data/_internal/batcher.py +++ b/python/ray/data/_internal/batcher.py @@ -319,64 +319,3 @@ def next_batch(self) -> Block: self._batch_head += batch_size # Yield the shuffled batch. return BlockAccessor.for_block(self._shuffle_buffer).take(batch_indices) - - -def _blocks_to_batches( - block_iter: Iterator[Block], - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, - batch_size: Optional[int] = None, - drop_last: bool = False, - shuffle_buffer_min_size: Optional[int] = None, - shuffle_seed: Optional[int] = None, - ensure_copy: bool = False, -) -> Iterator[Block]: - """Given an iterator over blocks, returns an iterator over blocks - of the appropriate bacth size. - - If the shuffling configurations are specified, then the - output blocks contain shuffled data. - - Args: - block_iter: An iterator over blocks. - stats: Dataset stats object used to store block batching time. - batch_size: Record batch size, or None to let the system pick. - drop_last: Whether to drop the last batch if it's incomplete. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). - - Returns: - An iterator over blocks of the given size that are potentially shuffled. - """ - if shuffle_buffer_min_size is not None: - batcher = ShufflingBatcher( - batch_size=batch_size, - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - else: - batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) - - def get_iter_next_batch_s_timer(): - return stats.iter_next_batch_s.timer() if stats else nullcontext() - - for block in block_iter: - batcher.add(block) - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Signal to the batcher that there are no more blocks to add. - batcher.done_adding() - - # Get any leftover batches in ShufflingBatcher. - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Get any remaining data. - if not drop_last and batcher.has_any(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch \ No newline at end of file diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 8554d2ef99b4e..d866e9b621465 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -6,7 +6,6 @@ import ray from ray.data._internal.block_batching.interfaces import BlockPrefetcher from ray.data._internal.block_batching.util import ( - _make_async_gen, WaitBlockPrefetcher, ActorBlockPrefetcher, ) @@ -42,7 +41,6 @@ def batch_block_refs( shuffle_buffer_min_size: Optional[int] = None, shuffle_seed: Optional[int] = None, ensure_copy: bool = False, - prefetch_batches: int = 0, ) -> Iterator[DataBatch]: """Create formatted batches of data from 1 or more block object references. @@ -76,12 +74,6 @@ def batch_block_refs( shuffle_seed: The seed to use for the local random shuffle. ensure_copy: Whether batches are always copied from the underlying base blocks (not zero-copy views). - prefetch_batches: The number of batches to fetch ahead of the current batch to - process. If set to greater than 0, a separate thread will be used to fetch - the specified amount of formatted batches from blocks. This improves - performance for non-CPU bound UDFs, allowing batch fetching compute and - formatting to be overlapped with the UDF. Defaults to 0 (no prefetching - enabled). Returns: An iterator over record batches. @@ -119,7 +111,6 @@ def batch_block_refs( shuffle_buffer_min_size=shuffle_buffer_min_size, shuffle_seed=shuffle_seed, ensure_copy=ensure_copy, - prefetch_batches=prefetch_batches, ) @@ -134,7 +125,6 @@ def batch_blocks( shuffle_buffer_min_size: Optional[int] = None, shuffle_seed: Optional[int] = None, ensure_copy: bool = False, - prefetch_batches: int = 0, ) -> Iterator[DataBatch]: """Create formatted batches of data from 1 or more blocks. @@ -167,12 +157,7 @@ def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: batch_iter = batch_fn_iter(batch_iter) yield from batch_iter - if prefetch_batches > 0: - batch_iter = _make_async_gen( - blocks, fn=_iterator_fn, num_workers=prefetch_batches - ) - else: - batch_iter = _iterator_fn(blocks) + batch_iter = _iterator_fn(blocks) for formatted_batch in batch_iter: user_timer = stats.iter_user_s.timer() if stats else nullcontext() @@ -279,6 +264,68 @@ def _prefetch_blocks( block_ref, "block_batching._prefetch_blocks", free=eager_free ) + +def _blocks_to_batches( + block_iter: Iterator[Block], + stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, + batch_size: Optional[int] = None, + drop_last: bool = False, + shuffle_buffer_min_size: Optional[int] = None, + shuffle_seed: Optional[int] = None, + ensure_copy: bool = False, +) -> Iterator[Block]: + """Given an iterator over blocks, returns an iterator over blocks + of the appropriate bacth size. + + If the shuffling configurations are specified, then the + output blocks contain shuffled data. + + Args: + block_iter: An iterator over blocks. + stats: Dataset stats object used to store block batching time. + batch_size: Record batch size, or None to let the system pick. + drop_last: Whether to drop the last batch if it's incomplete. + ensure_copy: Whether batches are always copied from the underlying base + blocks (not zero-copy views). + + Returns: + An iterator over blocks of the given size that are potentially shuffled. + """ + if shuffle_buffer_min_size is not None: + batcher = ShufflingBatcher( + batch_size=batch_size, + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + else: + batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) + + def get_iter_next_batch_s_timer(): + return stats.iter_next_batch_s.timer() if stats else nullcontext() + + for block in block_iter: + batcher.add(block) + while batcher.has_batch(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield batch + + # Signal to the batcher that there are no more blocks to add. + batcher.done_adding() + + # Get any leftover batches in ShufflingBatcher. + while batcher.has_batch(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield batch + + # Get any remaining data. + if not drop_last and batcher.has_any(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield batch + + def _format_batches( block_iter: Iterator[Block], batch_format: str, diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 9c8f5ea53605b..2c0d68a2ea351 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -12,7 +12,12 @@ LogicalBatch, BlockPrefetcher, ) -from ray.data._internal.block_batching.util import _calculate_ref_hits, _make_async_gen, ActorBlockPrefetcher, WaitBlockPrefetcher +from ray.data._internal.block_batching.util import ( + _calculate_ref_hits, + _make_async_gen, + ActorBlockPrefetcher, + WaitBlockPrefetcher, +) from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetStats @@ -112,7 +117,9 @@ def iter_batches( eager_free = clear_block_after_read and context.eager_free - def _async_iter_batches(block_refs: Iterator[ObjectRef[Block]]) -> Iterator[DataBatch]: + def _async_iter_batches( + block_refs: Iterator[ObjectRef[Block]], + ) -> Iterator[DataBatch]: # Step 1: Construct logical batches based on the metadata. batch_iter = _bundle_block_refs_to_logical_batches( block_refs, batch_size=batch_size, drop_last=drop_last @@ -135,14 +142,21 @@ def _async_iter_batches(block_refs: Iterator[ObjectRef[Block]]) -> Iterator[Data # Step 4: Use a threadpool for resolving blocks, slicing, formatting, and # collation. - batch_iter = _batch_in_threadpool(batch_iter, stats=stats, batch_format=batch_format, collate_fn=collate_fn, ensure_copy=ensure_copy, num_threadpool_workers=prefetch_batches) + batch_iter = _batch_in_threadpool( + batch_iter, + stats=stats, + batch_format=batch_format, + collate_fn=collate_fn, + ensure_copy=ensure_copy, + num_threadpool_workers=prefetch_batches, + ) # Step 5: Trace deallocation batch_iter = _trace_deallocation(batch_iter, eager_free=eager_free) # Step 6: Restore original order. batch_iter: Iterator[Batch] = _restore_from_original_order(batch_iter) - + for batch in batch_iter: yield batch.data @@ -161,16 +175,17 @@ def _async_iter_batches(block_refs: Iterator[ObjectRef[Block]]) -> Iterator[Data with stats.iter_user_s.timer() if stats else nullcontext(): yield next_batch + def _batch_in_threadpool( - logical_batch_iterator: Iterator[LogicalBatch], - stats: DatasetStats, - batch_format: str = "default", - collate_fn: Optional[Callable[[DataBatch], Any]] = None, - ensure_copy: bool = False, - num_threadpool_workers: int = 0, + logical_batch_iterator: Iterator[LogicalBatch], + stats: DatasetStats, + batch_format: str = "default", + collate_fn: Optional[Callable[[DataBatch], Any]] = None, + ensure_copy: bool = False, + num_threadpool_workers: int = 0, ) -> Iterator[Batch]: """Executes the batching, formatting, and collation logic in a threadpool. - + Args: logical_batch_iterator: An iterator over logical batches. stats: DatasetStats object to record timing and other statistics. @@ -185,7 +200,9 @@ def _batch_in_threadpool( num_threadpool_workers: The number of threads to use in the threadpool. """ - def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]) -> Iterator[Batch]: + def threadpool_computations( + logical_batch_iter: Iterator[LogicalBatch], + ) -> Iterator[Batch]: # Step 4.1: Resolve the blocks. resolved_batch_iter = _resolve_logical_batch(logical_batch_iter, stats=stats) @@ -206,7 +223,12 @@ def threadpool_computations(logical_batch_iter: Iterator[LogicalBatch]) -> Itera ) yield from formatted_batch_iter - return _make_async_gen(base_iterator=logical_batch_iterator, fn=threadpool_computations, num_workers=num_threadpool_workers) + return _make_async_gen( + base_iterator=logical_batch_iterator, + fn=threadpool_computations, + num_workers=num_threadpool_workers, + ) + def _bundle_block_refs_to_logical_batches( block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], @@ -370,15 +392,20 @@ def get_next_batches() -> Iterator[List[LogicalBatch]]: for batch in batches: yield batch -def _resolve_logical_batch(logical_batch_iter: Iterator[LogicalBatch], stats: Optional[DatasetStats]=None): + +def _resolve_logical_batch( + logical_batch_iter: Iterator[LogicalBatch], stats: Optional[DatasetStats] = None +): """Resolves the block references for each logical batch.""" hits = 0 misses = 0 unknowns = 0 - + for logical_batch in logical_batch_iter: - current_hit, current_miss, current_unknown = _calculate_ref_hits(logical_batch.block_refs) + current_hit, current_miss, current_unknown = _calculate_ref_hits( + logical_batch.block_refs + ) hits += current_hit misses += current_miss unknowns += current_unknown @@ -392,10 +419,11 @@ def _resolve_logical_batch(logical_batch_iter: Iterator[LogicalBatch], stats: Op stats.iter_blocks_remote += misses stats.iter_unknown_location += unknowns + def _construct_batch_from_logical_batch( resolved_logical_batch_iter: Iterator[LogicalBatch], ensure_copy: bool = False, - stats: Optional[DatasetStats] = None + stats: Optional[DatasetStats] = None, ) -> Iterator[Tuple[int, Block]]: """Given an iterator over logical batches, returns an iterator over actual constructed batches. @@ -473,7 +501,7 @@ def _format_batches( def _collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], - stats: Optional[DatasetStats] = None + stats: Optional[DatasetStats] = None, ) -> Iterator[Tuple[int, Any]]: """Returns an iterator with the provided collate_fn applied to items of the batch iterator. diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 5ac858dacaa2d..0453e38baa1cd 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -6,7 +6,6 @@ import ray from ray.types import ObjectRef from ray.actor import ActorHandle -from ray.types import ObjectRef from ray.data.block import Block from ray.data._internal.block_batching.interfaces import BlockPrefetcher from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -41,7 +40,7 @@ def _make_async_gen( if num_workers <= 0: yield from fn(base_iterator) return - + def convert_to_threadsafe_iterator(base_iterator: Iterator[T]) -> Iterator[T]: class ThreadSafeIterator: def __init__(self, it): @@ -98,8 +97,11 @@ def execute_computation(thread_index: int): if num_threads_finished >= num_workers: break + def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: - """Given a list of object references, returns how many are already on the local node, how many require fetching from another node, and how many have unknown locations.""" + """Given a list of object references, returns how many are already on the local + node, how many require fetching from another node, and how many have unknown + locations.""" current_node_id = ray.get_runtime_context().get_node_id() locs = ray.experimental.get_object_locations(refs) @@ -149,4 +151,4 @@ class _BlockPretcher: """Helper actor that prefetches blocks asynchronously.""" def prefetch(self, *blocks) -> None: - pass \ No newline at end of file + pass diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index 21783d917bfad..dc786eefe8701 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -41,14 +41,14 @@ def execute_to_legacy_block_iterator( plan: ExecutionPlan, allow_clear_input_blocks: bool, dataset_uuid: str, -) -> Iterator[ObjectRef[Block]]: - """Same as execute_to_legacy_bundle_iterator but returning blocks.""" +) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: + """Same as execute_to_legacy_bundle_iterator but returning blocks and metadata.""" bundle_iter = execute_to_legacy_bundle_iterator( executor, plan, allow_clear_input_blocks, dataset_uuid ) for bundle in bundle_iter: - for block, _ in bundle.blocks: - yield block + for block, metadata in bundle.blocks: + yield block, metadata def execute_to_legacy_bundle_iterator( diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index acce313b959ac..0ef8cf68ac100 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -16,6 +16,7 @@ ) import ray +from ray.data.block import BlockMetadata from ray.data._internal.util import capitalize from ray.types import ObjectRef from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas @@ -481,7 +482,11 @@ def execute_to_iterator( self, allow_clear_input_blocks: bool = True, force_read: bool = False, - ) -> Tuple[Iterator[ObjectRef[Block]], DatasetStats, Optional["Executor"]]: + ) -> Tuple[ + Iterator[Tuple[ObjectRef[Block], BlockMetadata]], + DatasetStats, + Optional["Executor"], + ]: """Execute this plan, returning an iterator. If the streaming execution backend is enabled, this will use streaming @@ -499,7 +504,9 @@ def execute_to_iterator( ctx = DatasetContext.get_current() if not ctx.use_streaming_executor: return ( - self.execute(allow_clear_input_blocks, force_read).iter_blocks(), + self.execute( + allow_clear_input_blocks, force_read + ).iter_blocks_with_metadata(), self._snapshot_stats, None, ) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index b6ea124df92c6..1cd1c3dd54186 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -51,8 +51,12 @@ from ray.data._internal.planner.write import generate_write_fn from ray.data.dataset_iterator import DatasetIterator from ray.data._internal.block_list import BlockList -from ray.data._internal.dataset_iterator.dataset_iterator_impl import DatasetIteratorImpl -from ray.data._internal.dataset_iterator.stream_split_dataset_iterator import StreamSplitDatasetIterator +from ray.data._internal.dataset_iterator.dataset_iterator_impl import ( + DatasetIteratorImpl, +) +from ray.data._internal.dataset_iterator.stream_split_dataset_iterator import ( + StreamSplitDatasetIterator, +) from ray.data._internal.compute import ( ActorPoolStrategy, CallableClass, @@ -2927,13 +2931,15 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]] def iter_batches( self, *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: Optional[int] = 256, batch_format: Optional[str] = "default", drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, _collate_fn: Optional[Callable[[DataBatch], Any]] = None, + # Deprecated. + prefetch_blocks: int = 0, ) -> Iterator[DataBatch]: """Return a local batched iterator over the dataset. @@ -2945,8 +2951,12 @@ def iter_batches( Time complexity: O(1) Args: - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -2976,6 +2986,7 @@ def iter_batches( ) return self.iterator().iter_batches( + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, batch_size=batch_size, batch_format=batch_format, diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 374b3a89895a7..1dc2463d74331 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -1,10 +1,26 @@ import abc import numpy as np -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, Iterator - -from ray.data.block import BlockAccessor, DataBatch, T +import time +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + Iterator, +) + +from ray.types import ObjectRef +from ray.data.block import BlockAccessor, Block, BlockMetadata, DataBatch, T +from ray.data.context import DatasetContext from ray.data.row import TableRow from ray.util.annotations import PublicAPI +from ray.data._internal.block_batching import batch_block_refs +from ray.data._internal.block_batching.iter_batches import iter_batches +from ray.data._internal.stats import DatasetStats from ray.data._internal.util import _is_tensor_schema if TYPE_CHECKING: @@ -48,17 +64,34 @@ class DatasetIterator(abc.ABC): :class:`~ray.data.Preprocessor`, and a :class:`~ray.air.DatasetConfig`. """ + @abc.abstractmethod + def _to_block_iterator( + self, + ) -> Tuple[ + Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats] + ]: + """Returns the iterator to use for `iter_batches`. + + Returns: + A tuple. The first item of the tuple is an iterator over pairs of Block + object references and their corresponding metadata. The second item of the + tuple is a DatasetStats object used for recording stats during iteration. + """ + raise NotImplementedError + @abc.abstractmethod def iter_batches( self, *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: int = 256, batch_format: Optional[str] = "default", drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, _collate_fn: Optional[Callable[[DataBatch], Any]] = None, + # Deprecated. + prefetch_blocks: int = 0, ) -> Iterator[DataBatch]: """Return a local batched iterator over the dataset. @@ -72,8 +105,12 @@ def iter_batches( Time complexity: O(1) Args: - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -97,7 +134,58 @@ def iter_batches( Returns: An iterator over record batches. """ - raise NotImplementedError + + context = DatasetContext.get_current() + if not context.use_streaming_executor: + # Always use legacy iter_batches for bulk executor. + use_legacy = True + else: + use_legacy = context.use_legacy_iter_batches + + if prefetch_blocks > 0 and not use_legacy: + raise DeprecationWarning( + "`prefetch_blocks` arg is deprecated in Ray 2.4. Use " + "the`prefetch_batches` arg instead to specify the amount of " + "prefetching in terms of batches instead of blocks. If you " + "would like to use the legacy `iter_batches` codepath, " + "you can enable it by setting `use_legacy_iter_batches` " + "to True in the DatasetContext." + ) + + time_start = time.perf_counter() + + block_iterator, stats = self._to_block_iterator() + if use_legacy: + # Legacy iter_batches does not use metadata. + def drop_metadata(block_iterator): + for block_ref, metadata in block_iterator: + yield block_ref + + yield from batch_block_refs( + drop_metadata(block_iterator), + stats=stats, + prefetch_blocks=prefetch_blocks, + batch_size=batch_size, + batch_format=batch_format, + drop_last=drop_last, + collate_fn=_collate_fn, + shuffle_buffer_min_size=local_shuffle_buffer_size, + shuffle_seed=local_shuffle_seed, + ) + else: + yield from iter_batches( + block_iterator, + stats=stats, + batch_size=batch_size, + batch_format=batch_format, + drop_last=drop_last, + collate_fn=_collate_fn, + shuffle_buffer_min_size=local_shuffle_buffer_size, + shuffle_seed=local_shuffle_seed, + prefetch_batches=prefetch_batches, + ) + + stats.iter_total_s.add(time.perf_counter() - time_start) def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]: """Return a local row iterator over the dataset. @@ -123,6 +211,8 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]] """ for batch in self.iter_batches( batch_size=None, + # If batch_size is None, 1 block is exactly 1 batch. + prefetch_batches=prefetch_blocks, prefetch_blocks=prefetch_blocks, batch_format=None, ): @@ -143,7 +233,7 @@ def schema(self) -> Union[type, "pyarrow.lib.Schema"]: def iter_torch_batches( self, *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: Optional[int] = 256, dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, device: Optional[str] = None, @@ -153,6 +243,8 @@ def iter_torch_batches( drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, + # Deprecated. + prefetch_blocks: int = 0, ) -> Iterator["TorchTensorBatchType"]: """Return a local batched iterator of Torch Tensors over the dataset. @@ -171,8 +263,12 @@ def iter_torch_batches( Time complexity: O(1) Args: - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -229,6 +325,7 @@ def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): ) yield from self.iter_batches( + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, batch_size=batch_size, batch_format="numpy", @@ -241,12 +338,14 @@ def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): def iter_tf_batches( self, *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: Optional[int] = 256, dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, + # Deprecated. + prefetch_blocks: int = 0, ) -> Iterator["TensorFlowTensorBatchType"]: """Return a local batched iterator of TensorFlow Tensors over the dataset. @@ -272,8 +371,12 @@ def iter_tf_batches( Time complexity: O(1) Args: - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -299,6 +402,7 @@ def iter_tf_batches( ) for batch in self.iter_batches( + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, batch_size=batch_size, batch_format="numpy", @@ -320,12 +424,14 @@ def to_torch( Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]] ] = None, batch_size: int = 1, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, unsqueeze_label_tensor: bool = True, unsqueeze_feature_tensors: bool = True, + # Deprecated. + prefetch_blocks: int = 0, ) -> "torch.utils.data.IterableDataset": """Return a Torch IterableDataset over this dataset. @@ -386,8 +492,12 @@ def to_torch( all tensors. If None, then automatically infer the dtype. batch_size: How many samples per batch to yield at a time. Defaults to 1. - prefetch_blocks: The number of blocks to prefetch ahead of - the current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch @@ -458,6 +568,7 @@ def make_generator(): batch_size=batch_size, batch_format="pandas", prefetch_blocks=prefetch_blocks, + prefetch_batches=prefetch_batches, drop_last=drop_last, local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, @@ -502,11 +613,13 @@ def to_tf( feature_columns: Union[str, List[str]], label_columns: Union[str, List[str]], *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: int = 1, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, + # Deprecated. + prefetch_blocks: int = 0, ) -> "tf.data.Dataset": """Return a TF Dataset over this dataset. @@ -573,8 +686,12 @@ def to_tf( label_column: Columns that correspond to model targets. If this is a string, the target data is a tensor. If this is a list, the target data is a ``dict`` that maps column names to their tensor representation. - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: Record batch size. Defaults to 1. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If @@ -655,6 +772,7 @@ def convert_batch_to_tensors( def generator(): for batch in self.iter_batches( + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, batch_size=batch_size, drop_last=drop_last, diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py index c07e6c177b8ab..c02062b3d2c9f 100644 --- a/python/ray/data/dataset_pipeline.py +++ b/python/ray/data/dataset_pipeline.py @@ -28,7 +28,9 @@ PipelineExecutor, PipelineSplitExecutorCoordinator, ) -from ray.data._internal.dataset_iterator.pipelined_dataset_iterator import PipelinedDatasetIterator +from ray.data._internal.dataset_iterator.pipelined_dataset_iterator import ( + PipelinedDatasetIterator, +) from ray.data._internal.plan import ExecutionPlan from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.data.block import BatchUDF, Block, DataBatch, KeyFn, RowUDF @@ -793,7 +795,6 @@ def map_batches( batch_size: Optional[Union[int, Literal["default"]]] = "default", compute: Optional[Union[str, ComputeStrategy]] = None, batch_format: Optional[str] = "default", - prefetch_batches: int = 0, fn_args: Optional[Iterable[Any]] = None, fn_kwargs: Optional[Dict[str, Any]] = None, fn_constructor_args: Optional[Iterable[Any]] = None, @@ -808,7 +809,6 @@ def map_batches( batch_size=batch_size, compute=compute, batch_format=batch_format, - prefetch_batches=prefetch_batches, fn_args=fn_args, fn_kwargs=fn_kwargs, fn_constructor_args=fn_constructor_args, diff --git a/python/ray/data/tests/block_batching/test_block_batching.py b/python/ray/data/tests/block_batching/test_block_batching.py index 357eff91cc42f..1cc6d727deaea 100644 --- a/python/ray/data/tests/block_batching/test_block_batching.py +++ b/python/ray/data/tests/block_batching/test_block_batching.py @@ -1,5 +1,4 @@ import pytest -import time from typing import List from unittest import mock @@ -122,46 +121,6 @@ def test_format_batches(batch_format): assert isinstance(batch["foo"], np.ndarray) -# Test for 3 cases -# 1. Batch size is less than block size -# 2. Batch size is more than block size -# 3. Block size is not divisble by batch size -@pytest.mark.parametrize("batch_size", [4, 10, 7]) -def test_async_batch_fetching(batch_size): - blocks = block_generator(num_blocks=5, num_rows=8) - - def sleep_batch_format(batch_iter, *args, **kwargs): - for batch in batch_iter: - time.sleep(2) - yield batch - - with mock.patch( - "ray.data._internal.block_batching.block_batching._format_batches", - sleep_batch_format, - ): - batch_iter = batch_blocks( - batch_size=batch_size, blocks=blocks, prefetch_batches=1 - ) - outputs = [] - start_time = time.time() - for batch in batch_iter: - time.sleep(3) - outputs.append(batch) - end_time = time.time() - - total_time = end_time - start_time - # Total time should be based on number of times the udf is called - # (which is equal to len(outputs)). - # The 2 seconds sleep in sleep_batch_format is overlapped, so does not count - # towards total time. - assert total_time < len(outputs) * 3 + 3 - - # There should be no dropped rows. - assert sum(len(output_batch) for output_batch in outputs) == 40, sum( - len(output_batch) for output_batch in outputs - ) # 5 blocks with 8 rows each. - - if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 8e02f96f52699..ef8da38fa52d4 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -299,16 +299,14 @@ def test_restore_from_original_order(): idx = [batch.batch_idx for batch in ordered] assert idx == [0, 1, 2, 3] + # Test for 3 cases # 1. Batch size is less than block size # 2. Batch size is more than block size # 3. Block size is not divisble by batch size @pytest.mark.parametrize("batch_size", [1, 4, 3]) @pytest.mark.parametrize("drop_last", [True, False]) -def test_iter_batches_e2e( - ray_start_regular_shared, batch_size, drop_last -): - +def test_iter_batches_e2e(ray_start_regular_shared, batch_size, drop_last): def collate_fn(batch: pd.DataFrame): return batch + 1 @@ -348,6 +346,7 @@ def test_iter_batches_e2e_async(ray_start_regular_shared): 2. In the collate_fn to simulate expensive slicing/formatting/collation 3. In the user thread to simulate training. """ + def collate_fn(batch): time.sleep(2) return batch diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index a7f7ece276812..1cd81de2b7bea 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -92,6 +92,7 @@ def test_calculate_ref_hits(ray_start_regular_shared): assert misses == 0 assert unknowns == 0 + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_batcher.py b/python/ray/data/tests/test_batcher.py index 2262a2c30bc91..20769fb4f85ea 100644 --- a/python/ray/data/tests/test_batcher.py +++ b/python/ray/data/tests/test_batcher.py @@ -2,7 +2,7 @@ import pyarrow as pa -from ray.data._internal.batcher import ShufflingBatcher, _blocks_to_batches +from ray.data._internal.batcher import ShufflingBatcher def gen_block(num_rows): @@ -127,39 +127,6 @@ def next_and_check( ) -@pytest.mark.parametrize("block_size", [1, 10]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_blocks_to_batches(block_size, drop_last): - def block_generator(num_rows, num_blocks): - for _ in range(num_blocks): - yield gen_block(num_rows) - - num_blocks = 5 - block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) - - batch_size = 3 - batch_iter = _blocks_to_batches( - block_iter, batch_size=batch_size, drop_last=drop_last - ) - - if drop_last: - for batch in batch_iter: - assert len(batch) == batch_size - else: - full_batches = 0 - leftover_batches = 0 - - dataset_size = block_size * num_blocks - for batch in batch_iter: - if len(batch) == batch_size: - full_batches += 1 - if len(batch) == (dataset_size % batch_size): - leftover_batches += 1 - - assert leftover_batches == 1 - assert full_batches == (dataset_size // batch_size) - - if __name__ == "__main__": import sys From 8bed11e10ed8f091946115c50922edb6caddf127 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 19:10:16 -0700 Subject: [PATCH 20/75] stats Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index b1d2abfe48f21..d19700cc13f5c 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -286,10 +286,11 @@ def to_summary(self) -> "DatasetStatsSummary": ) iter_stats = IterStatsSummary( - self.iter_wait_s, self.iter_get_s, - self.iter_next_batch_s, + self.iter_create_batch_s, self.iter_format_batch_s, + self.iter_collate_batch_s, + self.iter_total_blocked_s, self.iter_user_s, self.iter_total_s, self.iter_blocks_local, @@ -620,14 +621,16 @@ def __str__(self) -> str: @dataclass class IterStatsSummary: - # Time spent in `ray.wait()`, in seconds - wait_time: Timer # Time spent in `ray.get()`, in seconds get_time: Timer - # Time spent in `batcher.next_batch()`, in seconds + # Time spent in batch building, in seconds next_time: Timer # Time spent in `_format_batch_()`, in seconds format_time: Timer + # Time spent in collate fn, in seconds + collate_time: Timer + # Total time user thread is blocked by iter_batches + block_time: Timer # Time spent in user code, in seconds user_time: Timer # Total time taken by Dataset iterator, in seconds @@ -649,17 +652,20 @@ def __str__(self) -> str: or self.get_time.get() ): out += "\nDataset iterator time breakdown:\n" - out += "* In ray.wait(): {}\n".format(fmt(self.wait_time.get())) - out += "* In ray.get(): {}\n".format(fmt(self.get_time.get())) + out += "* Total time user code is blocked: {}\n".format(fmt(self.block_time.get())) + out += "* Total time in user code: {}\n".format(fmt(self.user_time.get())) + out += "* Total time overall: {}\n".format(fmt(self.total_time.get())) out += "* Num blocks local: {}\n".format(self.iter_blocks_local) out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) out += "* Num blocks unknown location: {}\n".format( self.iter_unknown_location ) - out += "* In next_batch(): {}\n".format(fmt(self.next_time.get())) - out += "* In format_batch(): {}\n".format(fmt(self.format_time.get())) - out += "* In user code: {}\n".format(fmt(self.user_time.get())) - out += "* Total time: {}\n".format(fmt(self.total_time.get())) + out += "* Batch iteration time breakdown:\n" + out += " * In ray.get(): {}\n".format(fmt(self.get_time.get())) + out += " * In batch creation: {}\n".format(fmt(self.next_time.get())) + out += " * In batch formatting: {}\n".format(fmt(self.format_time.get())) + out += " * In collate_fn: {}\n".format(fmt(self.collate_time.get())) + return out From 11338810e76b67ecc72214f670f41e4a70d16efc Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 19:43:46 -0700 Subject: [PATCH 21/75] legacy stats Signed-off-by: amogkam --- .../block_batching/block_batching.py | 10 ++--- python/ray/data/_internal/stats.py | 44 +++++++++++++++++-- python/ray/data/dataset_iterator.py | 1 - 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index d866e9b621465..8cbd20654a3a2 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -78,7 +78,7 @@ def batch_block_refs( Returns: An iterator over record batches. """ - + stats._legacy_iter_batches = True context = DatasetContext.get_current() if ( @@ -248,15 +248,13 @@ def _prefetch_blocks( sliding_window = collections.deque( itertools.islice(block_ref_iter, window_size), maxlen=window_size ) - with stats.iter_wait_s.timer() if stats else nullcontext(): - prefetcher.prefetch_blocks(list(sliding_window)) + prefetcher.prefetch_blocks(list(sliding_window)) while sliding_window: block_ref = sliding_window.popleft() try: sliding_window.append(next(block_ref_iter)) - with stats.iter_wait_s.timer() if stats else nullcontext(): - prefetcher.prefetch_blocks(list(sliding_window)) + prefetcher.prefetch_blocks(list(sliding_window)) except StopIteration: pass yield block_ref @@ -301,7 +299,7 @@ def _blocks_to_batches( batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) def get_iter_next_batch_s_timer(): - return stats.iter_next_batch_s.timer() if stats else nullcontext() + return stats.iter_create_batch_s.timer() if stats else nullcontext() for block in block_iter: batcher.add(block) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index d19700cc13f5c..d18aba9357db5 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -216,6 +216,7 @@ def __init__( self.needs_stats_actor = needs_stats_actor self.stats_uuid = stats_uuid + self._legacy_iter_batches = False # Iteration stats, filled out if the user iterates over the dataset. self.iter_get_s: Timer = Timer() self.iter_create_batch_s: Timer = Timer() @@ -286,6 +287,7 @@ def to_summary(self) -> "DatasetStatsSummary": ) iter_stats = IterStatsSummary( + self._legacy_iter_batches, self.iter_get_s, self.iter_create_batch_s, self.iter_format_batch_s, @@ -621,6 +623,8 @@ def __str__(self) -> str: @dataclass class IterStatsSummary: + # Whether the legacy `iter_batches` is being used. + legacy_iter_batches: bool # Time spent in `ray.get()`, in seconds get_time: Timer # Time spent in batch building, in seconds @@ -643,16 +647,26 @@ class IterStatsSummary: iter_unknown_location: int def __str__(self) -> str: + if self.legacy_iter_batches: + return self.to_string_legacy() + else: + return self.to_string() + + def to_string(self) -> str: out = "" if ( self.total_time.get() - or self.wait_time.get() or self.next_time.get() or self.format_time.get() or self.get_time.get() + or self.user_time.get() + or self.block_time.get() + or self.collate_time.get() ): out += "\nDataset iterator time breakdown:\n" - out += "* Total time user code is blocked: {}\n".format(fmt(self.block_time.get())) + out += "* Total time user code is blocked: {}\n".format( + fmt(self.block_time.get()) + ) out += "* Total time in user code: {}\n".format(fmt(self.user_time.get())) out += "* Total time overall: {}\n".format(fmt(self.total_time.get())) out += "* Num blocks local: {}\n".format(self.iter_blocks_local) @@ -665,7 +679,31 @@ def __str__(self) -> str: out += " * In batch creation: {}\n".format(fmt(self.next_time.get())) out += " * In batch formatting: {}\n".format(fmt(self.format_time.get())) out += " * In collate_fn: {}\n".format(fmt(self.collate_time.get())) - + + return out + + def to_string_legacy(self) -> str: + """Iteration stats summary for legacy `iter_batches`.""" + + out = "" + if ( + self.total_time.get() + or self.wait_time.get() + or self.next_time.get() + or self.format_time.get() + or self.get_time.get() + ): + out += "\nDataset iterator time breakdown:\n" + out += "* In ray.get(): {}\n".format(fmt(self.get_time.get())) + out += "* Num blocks local: {}\n".format(self.iter_blocks_local) + out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) + out += "* Num blocks unknown location: {}\n".format( + self.iter_unknown_location + ) + out += "* In next_batch(): {}\n".format(fmt(self.next_time.get())) + out += "* In format_batch(): {}\n".format(fmt(self.format_time.get())) + out += "* In user code: {}\n".format(fmt(self.user_time.get())) + out += "* Total time: {}\n".format(fmt(self.total_time.get())) return out diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 1dc2463d74331..7b9f3315a33e6 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -79,7 +79,6 @@ def _to_block_iterator( """ raise NotImplementedError - @abc.abstractmethod def iter_batches( self, *, From 0a1885f07ca7551ba1c2d13505991aa6f6f21d3e Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 20:55:04 -0700 Subject: [PATCH 22/75] update Signed-off-by: amogkam --- python/ray/data/dataset.py | 52 +++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 1cd1c3dd54186..c8f8ca45bf661 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3000,7 +3000,7 @@ def iter_batches( def iter_torch_batches( self, *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: Optional[int] = 256, dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, device: Optional[str] = None, @@ -3010,6 +3010,8 @@ def iter_torch_batches( drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, + # Deprecated + prefetch_blocks: int = 0, ) -> Iterator["TorchTensorBatchType"]: """Return a local batched iterator of Torch Tensors over the dataset. @@ -3032,8 +3034,12 @@ def iter_torch_batches( Time complexity: O(1) Args: - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -3064,6 +3070,7 @@ def iter_torch_batches( An iterator over Torch Tensor batches. """ return self.iterator().iter_torch_batches( + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, batch_size=batch_size, dtypes=dtypes, @@ -3078,12 +3085,14 @@ def iter_torch_batches( def iter_tf_batches( self, *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: Optional[int] = 256, dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, + # Deprecated + prefetch_blocks: int = 0, ) -> Iterator[TensorFlowTensorBatchType]: """Return a local batched iterator of TensorFlow Tensors over the dataset. @@ -3109,8 +3118,12 @@ def iter_tf_batches( Time complexity: O(1) Args: - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -3132,6 +3145,7 @@ def iter_tf_batches( An iterator over TensorFlow Tensor batches. """ return self.iterator().iter_tf_batches( + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, batch_size=batch_size, dtypes=dtypes, @@ -3153,12 +3167,14 @@ def to_torch( Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]] ] = None, batch_size: int = 1, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, unsqueeze_label_tensor: bool = True, unsqueeze_feature_tensors: bool = True, + # Deprecated + prefetch_blocks: int = 0, ) -> "torch.utils.data.IterableDataset": """Return a Torch IterableDataset over this dataset. @@ -3219,8 +3235,12 @@ def to_torch( all tensors. If None, then automatically infer the dtype. batch_size: How many samples per batch to yield at a time. Defaults to 1. - prefetch_blocks: The number of blocks to prefetch ahead of - the current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch @@ -3255,6 +3275,7 @@ def to_torch( feature_column_dtypes=feature_column_dtypes, batch_size=batch_size, prefetch_blocks=prefetch_blocks, + prefetch_batches=prefetch_batches, drop_last=drop_last, local_shuffle_buffer_size=local_shuffle_buffer_size, local_shuffle_seed=local_shuffle_seed, @@ -3268,11 +3289,13 @@ def to_tf( feature_columns: Union[str, List[str]], label_columns: Union[str, List[str]], *, - prefetch_blocks: int = 0, + prefetch_batches: int = 0, batch_size: int = 1, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, + # Deprecated + prefetch_blocks: int = 0, ) -> "tf.data.Dataset": """Return a TF Dataset over this dataset. @@ -3337,8 +3360,12 @@ def to_tf( label_column: Columns that correspond to model targets. If this is a string, the target data is a tensor. If this is a list, the target data is a ``dict`` that maps column names to their tensor representation. - prefetch_blocks: The number of blocks to prefetch ahead of the - current block during the scan. + prefetch_batches: The number of batches to fetch ahead of the current batch + to fetch. If set to greater than 0, a separate threadpool will be used + to fetch the objects to the local node, format the batches, and apply + the collate_fn. Defaults to 0 (no prefetching enabled.) This is still + an alpha API. You can revert back to the old prefetching behavior by + setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: Record batch size. Defaults to 1. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If @@ -3367,6 +3394,7 @@ def to_tf( return self.iterator().to_tf( feature_columns=feature_columns, label_columns=label_columns, + prefetch_batches=prefetch_batches, prefetch_blocks=prefetch_blocks, drop_last=drop_last, batch_size=batch_size, From 34213f24006e8aac6971766e66022255a01f9b61 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 20:56:55 -0700 Subject: [PATCH 23/75] add iterator Signed-off-by: amogkam --- .../dataset_iterator/dataset_iterator_impl.py | 57 +++++ .../pipelined_dataset_iterator.py | 59 +++++ .../stream_split_dataset_iterator.py | 241 ++++++++++++++++++ 3 files changed, 357 insertions(+) create mode 100644 python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py create mode 100644 python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py create mode 100644 python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py diff --git a/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py b/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py new file mode 100644 index 0000000000000..e76d7236bba7e --- /dev/null +++ b/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING, Optional, Union, Iterator, Tuple +import time +import warnings + +from ray.types import ObjectRef +from ray.data.block import Block, BlockMetadata +from ray.data.context import DatasetContext +from ray.data.dataset_iterator import DatasetIterator +from ray.data._internal.stats import DatasetStats + +if TYPE_CHECKING: + import pyarrow + from ray.data import Dataset + + +class DatasetIteratorImpl(DatasetIterator): + def __init__( + self, + base_dataset: "Dataset", + ): + self._base_dataset = base_dataset + self._base_context = DatasetContext.get_current() + + def __repr__(self) -> str: + return f"DatasetIterator({self._base_dataset})" + + def _to_block_iterator(self) -> Tuple[Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]]: + ds = self._base_dataset + block_iterator, stats, executor = ds._plan.execute_to_iterator() + ds._current_executor = executor + return block_iterator, stats + + def stats(self) -> str: + return self._base_dataset.stats() + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + return self._base_dataset.schema() + + def __getattr__(self, name): + if name == "_base_dataset": + raise AttributeError() + + if hasattr(self._base_dataset, name) and not name.startswith("_"): + # Warning for backwards compatibility. TODO: remove this method in 2.5. + warnings.warn( + "session.get_dataset_shard returns a ray.data.DatasetIterator " + "instead of a Dataset/DatasetPipeline as of Ray v2.3. " + "Use iter_torch_batches(), to_tf(), or iter_batches() to " + "iterate over one epoch. See " + "https://docs.ray.io/en/latest/data/api/dataset_iterator.html " + "for full DatasetIterator docs.", + stacklevel=4, + ) + + return getattr(self._base_dataset, name) + + raise AttributeError() diff --git a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py new file mode 100644 index 0000000000000..7e687b6d3d72a --- /dev/null +++ b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING,Optional, Union, Iterator, Tuple +import warnings + +from ray.types import ObjectRef +from ray.data.block import Block, BlockMetadata +from ray.data.dataset_iterator import DatasetIterator +from ray.data._internal.stats import DatasetStats + +if TYPE_CHECKING: + import pyarrow + from ray.data import DatasetPipeline + + +class PipelinedDatasetIterator(DatasetIterator): + def __init__( + self, + base_dataset_pipeline: "DatasetPipeline", + ): + self._base_dataset_pipeline = base_dataset_pipeline + self._epoch_iterator = None + + def __repr__(self) -> str: + return f"DatasetIterator({self._base_dataset_pipeline})" + + def _get_next_dataset(self) -> "DatasetPipeline": + if self._epoch_iterator is None: + self._epoch_iterator = self._base_dataset_pipeline.iter_epochs() + + ds = next(self._epoch_iterator) + return ds + + def _to_block_iterator(self) -> Tuple[Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]]: + ds = self._get_next_dataset() + return ds.iterator()._to_block_iterator() + + def stats(self) -> str: + return self._base_dataset_pipeline.stats() + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + return self._base_dataset_pipeline.schema() + + def __getattr__(self, name): + if name == "_base_dataset_pipeline": + raise AttributeError + + if hasattr(self._base_dataset_pipeline, name) and not name.startswith("_"): + # Warning for backwards compatibility. TODO: remove this method in 2.5. + warnings.warn( + "session.get_dataset_shard returns a ray.data.DatasetIterator " + "instead of a Dataset/DatasetPipeline as of Ray v2.3. " + "Use iter_torch_batches(), to_tf(), or iter_batches() to " + "iterate over one epoch. See " + "https://docs.ray.io/en/latest/data/api/dataset_iterator.html " + "for full DatasetIterator docs." + ) + + return getattr(self._base_dataset_pipeline, name) + else: + return super().__getattr__(name) diff --git a/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py b/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py new file mode 100644 index 0000000000000..d82abe11959d5 --- /dev/null +++ b/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py @@ -0,0 +1,241 @@ +import copy +import logging +import time +import threading +from typing import ( + List, + Dict, + Optional, + Iterator, + Tuple, + Union, + TYPE_CHECKING, +) + +import ray + +from ray.data.dataset_iterator import DatasetIterator +from ray.data.block import Block, DataBatch, BlockMetadata +from ray.data.context import DatasetContext +from ray.data._internal.execution.streaming_executor import StreamingExecutor +from ray.data._internal.execution.legacy_compat import ( + execute_to_legacy_bundle_iterator, +) +from ray.data._internal.block_batching import batch_block_refs +from ray.data._internal.execution.operators.output_splitter import OutputSplitter +from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle +from ray.data._internal.stats import DatasetStats +from ray.types import ObjectRef +from ray.util.debug import log_once +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +if TYPE_CHECKING: + import pyarrow + from ray.data import Dataset + +logger = logging.getLogger(__name__) + + +BLOCKED_CLIENT_WARN_TIMEOUT = 30 + + +class StreamSplitDatasetIterator(DatasetIterator): + """Implements a collection of iterators over a shared data stream.""" + + @staticmethod + def create( + base_dataset: "Dataset", + n: int, + equal: bool, + locality_hints: Optional[List[NodeIdStr]], + ) -> List["StreamSplitDatasetIterator"]: + """Create a split iterator from the given base Dataset and options. + + See also: `Dataset.streaming_split`. + """ + ctx = DatasetContext.get_current() + + # To avoid deadlock, the concurrency on this actor must be set to at least `n`. + coord_actor = SplitCoordinator.options( + max_concurrency=n, + scheduling_strategy=NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), soft=False + ), + ).remote(ctx, base_dataset, n, equal, locality_hints) + + return [ + StreamSplitDatasetIterator(base_dataset, coord_actor, i) for i in range(n) + ] + + def __init__( + self, + base_dataset: "Dataset", + coord_actor: ray.actor.ActorHandle, + output_split_idx: int, + ): + self._base_dataset = base_dataset + self._coord_actor = coord_actor + self._output_split_idx = output_split_idx + + def _to_block_iterator(self) -> Tuple[Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]]: + + def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: + cur_epoch = ray.get( + self._coord_actor.start_epoch.remote(self._output_split_idx) + ) + future: ObjectRef[ + Optional[ObjectRef[Block]] + ] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx) + while True: + block_ref: Optional[Tuple[ObjectRef[Block], BlockMetadata]] = ray.get(future) + if not block_ref: + break + else: + future = self._coord_actor.get.remote( + cur_epoch, self._output_split_idx + ) + yield block_ref + + return gen_blocks(), None + + def stats(self) -> str: + """Implements DatasetIterator.""" + return self._base_dataset.stats() + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + """Implements DatasetIterator.""" + return self._base_dataset.schema() + + +@ray.remote(num_cpus=0) +class SplitCoordinator: + """Coordinator actor for routing blocks to output splits. + + This actor runs a streaming executor locally on its main thread. Clients can + retrieve results via actor calls running on other threads. + """ + + def __init__( + self, + ctx: DatasetContext, + dataset: "Dataset", + n: int, + equal: bool, + locality_hints: Optional[List[NodeIdStr]], + ): + # Automatically set locality with output to the specified location hints. + if locality_hints: + ctx.execution_options.locality_with_output = locality_hints + logger.info(f"Auto configuring locality_with_output={locality_hints}") + + DatasetContext._set_current(ctx) + self._base_dataset = dataset + self._n = n + self._equal = equal + self._locality_hints = locality_hints + self._lock = threading.RLock() + + # Guarded by self._lock. + self._next_bundle: Dict[int, RefBundle] = {} + self._unfinished_clients_in_epoch = n + self._cur_epoch = -1 + + def gen_epochs(): + while True: + executor = StreamingExecutor(copy.deepcopy(ctx.execution_options)) + + def add_split_op(dag): + return OutputSplitter(dag, n, equal, locality_hints) + + output_iterator = execute_to_legacy_bundle_iterator( + executor, + dataset._plan, + True, + dataset._plan._dataset_uuid, + dag_rewrite=add_split_op, + ) + yield output_iterator + + self._next_epoch = gen_epochs() + self._output_iterator = None + + def start_epoch(self, split_idx: int) -> str: + """Called to start an epoch. + + Returns: + UUID for the epoch, which must be used when accessing results via get(). + """ + + # Wait for all clients to arrive at the barrier before starting a new epoch. + epoch_id = self._barrier(split_idx) + return epoch_id + + def get(self, epoch_id: int, output_split_idx: int) -> Optional[Tuple[ObjectRef[Block], BlockMetadata]]: + """Blocking get operation. + + This is intended to be called concurrently from multiple clients. + """ + + if epoch_id != self._cur_epoch: + raise ValueError( + "Invalid iterator: the datastream has moved on to another epoch." + ) + + try: + # Ensure there is at least one bundle. + with self._lock: + if output_split_idx in self._next_bundle: + next_bundle = self._next_bundle[output_split_idx] + else: + next_bundle = None + + # Fetch next bundle if needed. + if next_bundle is None: + # This is a BLOCKING call, so do it outside the lock. + next_bundle = self._output_iterator.get_next(output_split_idx) + + bundle = next_bundle.blocks.pop() + + # Accumulate any remaining blocks in next_bundle map as needed. + with self._lock: + self._next_bundle[output_split_idx] = next_bundle + if not next_bundle.blocks: + del self._next_bundle[output_split_idx] + + return bundle + except StopIteration: + return None + + def _barrier(self, split_idx: int) -> int: + """Arrive and block until the start of the given epoch.""" + + # Decrement and await all clients to arrive here. + with self._lock: + starting_epoch = self._cur_epoch + self._unfinished_clients_in_epoch -= 1 + + start_time = time.time() + while ( + self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0 + ): + if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT: + if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"): + logger.warning( + f"StreamSplitDatasetIterator(epoch={starting_epoch}, " + f"split={split_idx}) blocked waiting on other clients " + f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All " + "clients must read from the DatasetIterator splits at " + "the same time. This warning will not be printed again " + "for this epoch." + ) + time.sleep(0.1) + + # Advance to the next epoch. + with self._lock: + if self._cur_epoch == starting_epoch: + self._cur_epoch += 1 + self._unfinished_clients_in_epoch = self._n + self._output_iterator = next(self._next_epoch) + + assert self._output_iterator is not None + return starting_epoch + 1 From 1b7f1b902af4b8ac39513e824b50d898b2f8604d Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 21:58:23 -0700 Subject: [PATCH 24/75] update Signed-off-by: amogkam --- .../_internal/block_batching/interfaces.py | 1 - .../_internal/block_batching/iter_batches.py | 27 +--------- .../dataset_iterator/dataset_iterator_impl.py | 11 +++-- .../pipelined_dataset_iterator.py | 10 ++-- .../stream_split_dataset_iterator.py | 20 +++++--- python/ray/data/_internal/stats.py | 49 +++++++++++++++++-- python/ray/data/dataset_iterator.py | 1 + .../tests/block_batching/test_iter_batches.py | 24 +++------ 8 files changed, 80 insertions(+), 63 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 5165b272b4b40..f925d42ca695c 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -68,7 +68,6 @@ class Batch: batch_idx: int data: DataBatch - logical_batch: LogicalBatch class BlockPrefetcher: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 2c0d68a2ea351..9a610808cc60d 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -19,7 +19,6 @@ WaitBlockPrefetcher, ) from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder -from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetStats if sys.version_info >= (3, 7): @@ -67,8 +66,7 @@ def iter_batches( 2. Perform the necessary batch slicing to construct full batches. 3. Format the batches to the provided batch format. 4. Apply the collate function - 5. Trace deallocation and eagerly clear block references if necessary. - 6. Fetch outputs from the threadpool, maintaining order of the batches. + 5. Fetch outputs from the threadpool, maintaining order of the batches. Args: block_refs: An iterator over block object references and their corresponding @@ -115,8 +113,6 @@ def iter_batches( else: prefetcher = WaitBlockPrefetcher() - eager_free = clear_block_after_read and context.eager_free - def _async_iter_batches( block_refs: Iterator[ObjectRef[Block]], ) -> Iterator[DataBatch]: @@ -151,10 +147,7 @@ def _async_iter_batches( num_threadpool_workers=prefetch_batches, ) - # Step 5: Trace deallocation - batch_iter = _trace_deallocation(batch_iter, eager_free=eager_free) - - # Step 6: Restore original order. + # Step 5: Restore original order. batch_iter: Iterator[Batch] = _restore_from_original_order(batch_iter) for batch in batch_iter: @@ -517,22 +510,6 @@ def _collate( yield batch -def _trace_deallocation( - batch_iter: Iterator[Batch], eager_free: bool -) -> Iterator[Batch]: - """Trace deallocation of the underlying block references for each batch. - - Args: - batch_iter: An iterator over batches. - eager_free: Whether to eagerly free the object reference from the object store. - """ - for batch in batch_iter: - block_refs = batch.logical_batch.block_refs - for block_ref in block_refs: - trace_deallocation(block_ref, loc="iter_batches", free=eager_free) - yield batch - - def _restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` diff --git a/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py b/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py index e76d7236bba7e..ce888ea97a97b 100644 --- a/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py +++ b/python/ray/data/_internal/dataset_iterator/dataset_iterator_impl.py @@ -1,5 +1,4 @@ from typing import TYPE_CHECKING, Optional, Union, Iterator, Tuple -import time import warnings from ray.types import ObjectRef @@ -23,13 +22,17 @@ def __init__( def __repr__(self) -> str: return f"DatasetIterator({self._base_dataset})" - - def _to_block_iterator(self) -> Tuple[Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]]: + + def _to_block_iterator( + self, + ) -> Tuple[ + Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats] + ]: ds = self._base_dataset block_iterator, stats, executor = ds._plan.execute_to_iterator() ds._current_executor = executor return block_iterator, stats - + def stats(self) -> str: return self._base_dataset.stats() diff --git a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py index 7e687b6d3d72a..f776827571886 100644 --- a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py +++ b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING,Optional, Union, Iterator, Tuple +from typing import TYPE_CHECKING, Optional, Union, Iterator, Tuple import warnings from ray.types import ObjectRef @@ -28,8 +28,12 @@ def _get_next_dataset(self) -> "DatasetPipeline": ds = next(self._epoch_iterator) return ds - - def _to_block_iterator(self) -> Tuple[Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]]: + + def _to_block_iterator( + self, + ) -> Tuple[ + Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats] + ]: ds = self._get_next_dataset() return ds.iterator()._to_block_iterator() diff --git a/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py b/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py index d82abe11959d5..205439d361f83 100644 --- a/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py +++ b/python/ray/data/_internal/dataset_iterator/stream_split_dataset_iterator.py @@ -15,13 +15,12 @@ import ray from ray.data.dataset_iterator import DatasetIterator -from ray.data.block import Block, DataBatch, BlockMetadata +from ray.data.block import Block, BlockMetadata from ray.data.context import DatasetContext from ray.data._internal.execution.streaming_executor import StreamingExecutor from ray.data._internal.execution.legacy_compat import ( execute_to_legacy_bundle_iterator, ) -from ray.data._internal.block_batching import batch_block_refs from ray.data._internal.execution.operators.output_splitter import OutputSplitter from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle from ray.data._internal.stats import DatasetStats @@ -77,8 +76,11 @@ def __init__( self._coord_actor = coord_actor self._output_split_idx = output_split_idx - def _to_block_iterator(self) -> Tuple[Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats]]: - + def _to_block_iterator( + self, + ) -> Tuple[ + Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats] + ]: def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: cur_epoch = ray.get( self._coord_actor.start_epoch.remote(self._output_split_idx) @@ -87,7 +89,9 @@ def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: Optional[ObjectRef[Block]] ] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx) while True: - block_ref: Optional[Tuple[ObjectRef[Block], BlockMetadata]] = ray.get(future) + block_ref: Optional[Tuple[ObjectRef[Block], BlockMetadata]] = ray.get( + future + ) if not block_ref: break else: @@ -95,7 +99,7 @@ def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: cur_epoch, self._output_split_idx ) yield block_ref - + return gen_blocks(), None def stats(self) -> str: @@ -170,7 +174,9 @@ def start_epoch(self, split_idx: int) -> str: epoch_id = self._barrier(split_idx) return epoch_id - def get(self, epoch_id: int, output_split_idx: int) -> Optional[Tuple[ObjectRef[Block], BlockMetadata]]: + def get( + self, epoch_id: int, output_split_idx: int + ) -> Optional[Tuple[ObjectRef[Block], BlockMetadata]]: """Blocking get operation. This is intended to be called concurrently from multiple clients. diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index d18aba9357db5..8450e6c0578d4 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -34,6 +34,9 @@ class Timer: def __init__(self): self._value: float = 0 + self._min: float = float("inf") + self._max: float = 0 + self._total_count: float = 0 @contextmanager def timer(self) -> None: @@ -41,14 +44,28 @@ def timer(self) -> None: try: yield finally: - self._value += time.perf_counter() - time_start + self.add(time.perf_counter() - time_start) def add(self, value: float) -> None: self._value += value + if value < self._min: + self._min = value + if value > self._max: + self._max = value + self._total_count += 1 def get(self) -> float: return self._value + def min(self) -> float: + return self._min + + def max(self) -> float: + return self._max + + def avg(self) -> float: + return self._value / self._total_count + class _DatasetStatsBuilder: """Helper class for building dataset stats. @@ -675,10 +692,32 @@ def to_string(self) -> str: self.iter_unknown_location ) out += "* Batch iteration time breakdown:\n" - out += " * In ray.get(): {}\n".format(fmt(self.get_time.get())) - out += " * In batch creation: {}\n".format(fmt(self.next_time.get())) - out += " * In batch formatting: {}\n".format(fmt(self.format_time.get())) - out += " * In collate_fn: {}\n".format(fmt(self.collate_time.get())) + out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format( + fmt(self.get_time.min()), + fmt(self.get_time.max()), + fmt(self.get_time.avg()), + fmt(self.get_time.get()), + ) + out += " * In batch creation: {} min, {} max, {} avg, {} total\n".format( + fmt(self.next_time.min()), + fmt(self.next_time.max()), + fmt(self.next_time.avg()), + fmt(self.next_time.get()), + ) + out += ( + " * In batch formatting: {} min, {} max, {} avg, {} total\n".format( + fmt(self.format_time.min()), + fmt(self.format_time.max()), + fmt(self.format_time.avg()), + fmt(self.format_time.get()), + ) + ) + out += " * In collate_fn: {} min, {} max, {} avg, {} total\n".format( + fmt(self.collate_time.min()), + fmt(self.collate_time.max()), + fmt(self.collate_time.avg()), + fmt(self.collate_time.get()), + ) return out diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 7b9f3315a33e6..57117ebff53bf 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -182,6 +182,7 @@ def drop_metadata(block_iterator): shuffle_buffer_min_size=local_shuffle_buffer_size, shuffle_seed=local_shuffle_seed, prefetch_batches=prefetch_batches, + clear_block_after_read=True, ) stats.iter_total_s.add(time.perf_counter() - time_start) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index ef8da38fa52d4..dda98015e2f7d 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -2,7 +2,6 @@ import pytest import time from typing import Iterator, List, Tuple -from unittest.mock import patch import numpy as np import pandas as pd @@ -25,7 +24,6 @@ _construct_batch_from_logical_batch, _format_batches, _collate, - _trace_deallocation, _restore_from_original_order, ) @@ -246,7 +244,7 @@ def test_construct_batch_from_logical_batch(ray_start_regular_shared, block_size @pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) def test_format_batches(ray_start_regular_shared, batch_format): batches = [ - Batch(i, ray.get(data[0]), None) + Batch(i, ray.get(data[0])) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) ] batch_iter = _format_batches(batches, batch_format=batch_format) @@ -267,7 +265,7 @@ def collate_fn(batch): return pa.table({"bar": [1] * 2}) batches = [ - Batch(i, ray.get(data[0]), None) + Batch(i, ray.get(data[0])) for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) ] batch_iter = _collate(batches, collate_fn=collate_fn) @@ -277,22 +275,12 @@ def collate_fn(batch): assert batch.data == pa.table({"bar": [1] * 2}) -@patch.object(ray.data._internal.block_batching.iter_batches, "trace_deallocation") -@pytest.mark.parametrize("eager_free", [True, False]) -def test_trace_deallocation(mock, eager_free): - batches = [Batch(0, 0, LogicalBatch(0, [0], 0, None, 1))] - batch_iter = _trace_deallocation(iter(batches), eager_free=eager_free) - # Test that the underlying batch is not modified. - assert next(batch_iter) == batches[0] - mock.assert_called_once_with(0, loc="iter_batches", free=eager_free) - - def test_restore_from_original_order(): base_iterator = [ - Batch(1, None, None), - Batch(0, None, None), - Batch(3, None, None), - Batch(2, None, None), + Batch(1, None), + Batch(0, None), + Batch(3, None), + Batch(2, None), ] ordered = list(_restore_from_original_order(iter(base_iterator))) From 8fee289bac3359c3ebc8245da5bed20adb803438 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 22:01:00 -0700 Subject: [PATCH 25/75] update Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 9a610808cc60d..1da9d540e2560 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -464,7 +464,7 @@ def _construct_batch_from_logical_batch( batch = BlockAccessor.for_block(batch) batch = batch.slice(0, batch.num_rows(), copy=True) - yield Batch(logical_batch.batch_idx, batch, logical_batch) + yield Batch(logical_batch.batch_idx, batch) def _format_batches( From d4843449861637b37ca84054960bd81bb742c3fe Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 22:19:04 -0700 Subject: [PATCH 26/75] release test Signed-off-by: amogkam --- .../dataset/iter_tensor_batches_benchmark.py | 31 +++++++++++++++++-- .../dataset/multi_node_benchmark_compute.yaml | 15 +++++++++ release/release_tests.yaml | 16 ++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 release/nightly_tests/dataset/multi_node_benchmark_compute.yaml diff --git a/release/nightly_tests/dataset/iter_tensor_batches_benchmark.py b/release/nightly_tests/dataset/iter_tensor_batches_benchmark.py index f3e3e1320c5bd..667ce3ba25bb3 100644 --- a/release/nightly_tests/dataset/iter_tensor_batches_benchmark.py +++ b/release/nightly_tests/dataset/iter_tensor_batches_benchmark.py @@ -1,3 +1,4 @@ +import argparse from typing import Optional, Union, List import ray @@ -55,9 +56,9 @@ def to_tf( return ds -def run_iter_tensor_batches_benchmark(benchmark: Benchmark): +def run_iter_tensor_batches_benchmark(benchmark: Benchmark, data_size_gb: int): ds = ray.data.read_images( - "s3://anonymous@air-example-data-2/1G-image-data-synthetic-raw" + f"s3://anonymous@air-example-data-2/{data_size_gb}G-image-data-synthetic-raw" ).cache() # Repartition both to align the block sizes so we can zip them. @@ -102,6 +103,18 @@ def run_iter_tensor_batches_benchmark(benchmark: Benchmark): batch_size=batch_size, ) + prefetch_batches = [1, 10] + # Test with varying prefetching for iter_torch_batches() + for prefetch_batch in prefetch_batches: + test_name = f"iter-torch-batches-prefetch-{32}-{prefetch_batches}" + benchmark.run( + test_name, + iter_torch_batches, + ds=ds, + batch_size=32, + prefetch_batches=prefetch_batch, + ) + # Test with varying batch sizes and shuffle for iter_torch_batches() and to_tf(). for batch_size in batch_sizes: for shuffle_buffer_size in [batch_size, 2 * batch_size]: @@ -128,8 +141,20 @@ def run_iter_tensor_batches_benchmark(benchmark: Benchmark): if __name__ == "__main__": ray.init() + parser = argparse.ArgumentParser( + description="Helper script to upload files to S3 bucket" + ) + parser.add_argument( + "--data-size-gb", + choices=[1, 10], + type=int, + help="The data size to use for the dataset.", + ) + + args = parser.parse_args() + benchmark = Benchmark("iter-tensor-batches") - run_iter_tensor_batches_benchmark(benchmark) + run_iter_tensor_batches_benchmark(benchmark, args.data_size_gb) benchmark.write_result() diff --git a/release/nightly_tests/dataset/multi_node_benchmark_compute.yaml b/release/nightly_tests/dataset/multi_node_benchmark_compute.yaml new file mode 100644 index 0000000000000..9634146e9a00d --- /dev/null +++ b/release/nightly_tests/dataset/multi_node_benchmark_compute.yaml @@ -0,0 +1,15 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 0 + +head_node_type: + name: head_node + instance_type: m5.4xlarge + +worker_node_types: + - name: worker_node + instance_type: m5.4xlarge + max_workers: 3 + min_workers: 3 + use_spot: false diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 19a7439d363ac..5a9d6c79328c3 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -3913,6 +3913,22 @@ timeout: 2400 script: python iter_tensor_batches_benchmark.py +- name: iter_tensor_batches_benchmark_multi_node + group: data-tests + working_dir: nightly_tests/dataset + + frequency: nightly + team: data + cluster: + cluster_env: app_config.yaml + cluster_compute: multi_node_benchmark_compute_yaml + + run: + # Expect the benchmark to finish around 30 minutes. + timeout: 2400 + script: python iter_tensor_batches_benchmark.py --data-size-gb=10 + + - name: iter_batches_benchmark_single_node group: data-tests From fb99ec2024dfb95a9d050eb67aa53b464f814e07 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 22:28:09 -0700 Subject: [PATCH 27/75] update Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 8 ++++---- python/ray/data/_internal/stats.py | 8 ++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 1da9d540e2560..0fa9021fac381 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -403,7 +403,7 @@ def _resolve_logical_batch( misses += current_miss unknowns += current_unknown - with stats.iter_get_s.timer() if stats else nullcontext(): + with stats.iter_get_s.thread_timer() if stats else nullcontext(): logical_batch.resolve() yield logical_batch @@ -433,7 +433,7 @@ def _construct_batch_from_logical_batch( """ for logical_batch in resolved_logical_batch_iter: - with stats.iter_create_batch_s.timer() if stats else nullcontext(): + with stats.iter_create_batch_s.thread_timer() if stats else nullcontext(): output = DelegatingBlockBuilder() slice_indices = [[0, None] for _ in range(len(logical_batch.blocks))] if logical_batch.starting_block_idx > 0: @@ -483,7 +483,7 @@ def _format_batches( An iterator over batch index and the formatted batch. """ for batch in block_iter: - with stats.iter_format_batch_s.timer() if stats else nullcontext(): + with stats.iter_format_batch_s.thread_timer() if stats else nullcontext(): formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( batch_format ) @@ -505,7 +505,7 @@ def _collate( stats: An optional stats object to record collation time. """ for batch in batch_iter: - with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + with stats.iter_collate_batch_s.thread_timer() if stats else nullcontext(): batch.data = collate_fn(batch.data) yield batch diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 8450e6c0578d4..a82f80f8aff36 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -46,6 +46,14 @@ def timer(self) -> None: finally: self.add(time.perf_counter() - time_start) + @contextmanager + def thread_timer(self) -> None: + time_start = time.thread_time() + try: + yield + finally: + self.add(time.thread_time() - time_start) + def add(self, value: float) -> None: self._value += value if value < self._min: From e6e429a593aedcb261212e88e123aed9254c12ea Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 22:36:00 -0700 Subject: [PATCH 28/75] lock Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index a82f80f8aff36..0ceec105bdb69 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import time from contextlib import contextmanager +import threading from typing import Dict, List, Optional, Set, Tuple, Union, Any import numpy as np @@ -38,6 +39,8 @@ def __init__(self): self._max: float = 0 self._total_count: float = 0 + self.lock = threading.Lock() + @contextmanager def timer(self) -> None: time_start = time.perf_counter() @@ -52,7 +55,8 @@ def thread_timer(self) -> None: try: yield finally: - self.add(time.thread_time() - time_start) + with self.lock(): + self.add(time.thread_time() - time_start) def add(self, value: float) -> None: self._value += value From 5522174b8478422490965bb85842477f495aced0 Mon Sep 17 00:00:00 2001 From: amogkam Date: Wed, 22 Mar 2023 22:37:09 -0700 Subject: [PATCH 29/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 0ceec105bdb69..a1ec0cff12ecd 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -55,7 +55,7 @@ def thread_timer(self) -> None: try: yield finally: - with self.lock(): + with self.lock: self.add(time.thread_time() - time_start) def add(self, value: float) -> None: From fab714ce8f3c2516a65e0e0abb4c8a1774d0f7c3 Mon Sep 17 00:00:00 2001 From: amogkam Date: Thu, 23 Mar 2023 23:06:57 -0700 Subject: [PATCH 30/75] update Signed-off-by: amogkam --- .../block_batching/block_batching.py | 186 ++--------- .../_internal/block_batching/interfaces.py | 54 +--- .../_internal/block_batching/iter_batches.py | 295 ++++-------------- .../ray/data/_internal/block_batching/util.py | 185 ++++++++++- .../block_batching/test_block_batching.py | 54 +--- .../tests/block_batching/test_interfaces.py | 28 -- .../tests/block_batching/test_iter_batches.py | 211 ++----------- .../data/tests/block_batching/test_util.py | 96 +++++- 8 files changed, 387 insertions(+), 722 deletions(-) delete mode 100644 python/ray/data/tests/block_batching/test_interfaces.py diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 2c4b380eb6a1d..db6a719a03a45 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -6,14 +6,17 @@ import ray from ray.data._internal.block_batching.interfaces import BlockPrefetcher from ray.data._internal.block_batching.util import ( - _make_async_gen, + resolve_block_refs, + blocks_to_batches, + format_batches, + collate, + extract_data_from_batch, + make_async_gen, WaitBlockPrefetcher, ActorBlockPrefetcher, ) -from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.stats import DatasetPipelineStats, DatasetStats -from ray.data._internal.memory_tracing import trace_deallocation -from ray.data.block import Block, BlockAccessor, DataBatch +from ray.data.block import Block, DataBatch from ray.data.context import DatasetContext from ray.types import ObjectRef @@ -98,15 +101,21 @@ def batch_block_refs( else: prefetcher = WaitBlockPrefetcher() - block_iter = _resolve_blocks( - _prefetch_blocks( - block_ref_iter=block_refs, - prefetcher=prefetcher, - stats=stats, - num_blocks_to_prefetch=prefetch_blocks, - clear_block_after_read=clear_block_after_read, + eager_free = clear_block_after_read and DatasetContext.get_current().eager_free + + block_iter = resolve_block_refs( + map( + list, + _prefetch_blocks( + block_ref_iter=block_refs, + prefetcher=prefetcher, + stats=stats, + num_blocks_to_prefetch=prefetch_blocks, + clear_block_after_read=clear_block_after_read, + ), ), stats=stats, + eager_free=eager_free, ) yield from batch_blocks( @@ -144,8 +153,8 @@ def batch_blocks( """ def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]: - batch_iter = _format_batches( - _blocks_to_batches( + batch_iter = format_batches( + blocks_to_batches( block_iter=base_iterator, stats=stats, batch_size=batch_size, @@ -159,16 +168,13 @@ def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]: ) if collate_fn is not None: + batch_iter = collate(batch_iter, collate_fn=collate_fn) - def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: - for batch in iterator: - yield collate_fn(batch) - - batch_iter = batch_fn_iter(batch_iter) + batch_iter = extract_data_from_batch(batch_iter) yield from batch_iter if prefetch_batches > 0: - batch_iter = _make_async_gen( + batch_iter = make_async_gen( blocks, fn=_iterator_fn, num_workers=prefetch_batches ) else: @@ -180,58 +186,10 @@ def batch_fn_iter(iterator: Iterator[DataBatch]) -> Iterator[DataBatch]: yield formatted_batch -def _resolve_blocks( - block_ref_iter: Iterator[ObjectRef[Block]], - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, -) -> Iterator[Block]: - """Given an iterator of unresolved blocks (as Ray object references), returns an - iterator of resolved blocks. - - The length of the returned iterator may be less than the length of the original - if any of the references in the original iterator are None. - - Args: - block_ref_iter: An iterator over block object references. - stats: Dataset stats object used to store block fetching time. - - Returns: - An iterator over resolved blocks. - """ - - hit = 0 - miss = 0 - unknown = 0 - for block_ref in block_ref_iter: - if block_ref is not None: - stats_timer = stats.iter_get_s.timer() if stats else nullcontext() - # Count the number of blocks that we hit locally or miss (so have to - # fetch from remote node). This is to measure the effectiveness of - # prefetch. - loc = ray.experimental.get_object_locations([block_ref]) - nodes = loc[block_ref]["node_ids"] - if nodes: - current = ray.get_runtime_context().get_node_id() - if current in nodes: - hit += 1 - else: - miss += 1 - else: - unknown += 1 - with stats_timer: - block = ray.get(block_ref) - yield block - - if stats: - stats.iter_blocks_local = hit - stats.iter_blocks_remote = miss - stats.iter_unknown_location = unknown - - def _prefetch_blocks( block_ref_iter: Iterator[ObjectRef[Block]], prefetcher: BlockPrefetcher, num_blocks_to_prefetch: int, - clear_block_after_read: bool = False, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[ObjectRef[Block]]: """Given an iterable of Block Object References, returns an iterator @@ -242,21 +200,11 @@ def _prefetch_blocks( block_ref_iter: An iterator over block object references. num_blocks_to_prefetch: The number of blocks to prefetch ahead of the current block during the scan. - clear_block_after_read: Whether to clear the block from object store - manually (i.e. without waiting for Python's automatic GC) after it - is read. Doing so will reclaim memory faster and hence reduce the - memory footprint. However, the caller has to ensure the safety, i.e. - the block will never be accessed again. stats: Dataset stats object used to store block wait time. """ - eager_free = clear_block_after_read and DatasetContext.get_current().eager_free - if num_blocks_to_prefetch == 0: for block_ref in block_ref_iter: yield block_ref - trace_deallocation( - block_ref, "block_batching._prefetch_blocks", free=eager_free - ) window_size = num_blocks_to_prefetch # Create the initial set of blocks to prefetch. @@ -275,87 +223,3 @@ def _prefetch_blocks( except StopIteration: pass yield block_ref - trace_deallocation( - block_ref, "block_batching._prefetch_blocks", free=eager_free - ) - - -def _blocks_to_batches( - block_iter: Iterator[Block], - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, - batch_size: Optional[int] = None, - drop_last: bool = False, - shuffle_buffer_min_size: Optional[int] = None, - shuffle_seed: Optional[int] = None, - ensure_copy: bool = False, -) -> Iterator[Block]: - """Given an iterator over blocks, returns an iterator over blocks - of the appropriate bacth size. - - If the shuffling configurations are specified, then the - output blocks contain shuffled data. - - Args: - block_iter: An iterator over blocks. - stats: Dataset stats object used to store block batching time. - batch_size: Record batch size, or None to let the system pick. - drop_last: Whether to drop the last batch if it's incomplete. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). - - Returns: - An iterator over blocks of the given size that are potentially shuffled. - """ - if shuffle_buffer_min_size is not None: - batcher = ShufflingBatcher( - batch_size=batch_size, - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - else: - batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) - - def get_iter_next_batch_s_timer(): - return stats.iter_next_batch_s.timer() if stats else nullcontext() - - for block in block_iter: - batcher.add(block) - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Signal to the batcher that there are no more blocks to add. - batcher.done_adding() - - # Get any leftover batches in ShufflingBatcher. - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Get any remaining data. - if not drop_last and batcher.has_any(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - -def _format_batches( - block_iter: Iterator[Block], - batch_format: str, - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, -) -> Iterator[DataBatch]: - """Given an iterator of blocks, returns an iterator of formatted batches. - - Args: - block_iter: An iterator over blocks. - batch_format: The batch format to use. - - Returns: - An iterator over formatted batches. - """ - for block in block_iter: - with stats.iter_format_batch_s.timer() if stats else nullcontext(): - batch = BlockAccessor.for_block(block).to_batch_format(batch_format) - yield batch diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 65332656ddc02..e2672863d9ace 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,60 +1,10 @@ from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any -import ray from ray.types import ObjectRef from ray.data.block import Block -@dataclass -class LogicalBatch: - """A logical "batch" of data. - - This is not a fully created batch, but rather a conceptual batch - consisting of unresolved Block Object references. - - Attributes: - bundle_idx: The global index of this bundle so that downstream operations can - maintain ordering. - block_refs: The list of block object references for this batch. - blocks: The resolved blocks for this batch. This attribute can only be accessed - after calling `.resolve()` - starting_block_idx: The index of the first block where this batch starts. - ending_block_idx: The index of the last block where this batch ends. This can - also be None, meaning the entirety of the last block is included in this - batch. If this value is None, this allows us to eagerly clear the last - block in this batch after reading, since the last block is not included in - any other batches. - num_rows: The number of rows in this batch. This should be equivalent to the - provided batch size, except for the final batch. - """ - - batch_idx: int - block_refs: List[ObjectRef[Block]] - starting_block_idx: int - ending_block_idx: Optional[int] - num_rows: int - - def __post_init__(self): - self._resolved = False - - def resolve(self): - """Resolves the block_refs in this LogicalBatch.""" - if self._resolved: - return - self._resolved = True - self._blocks = ray.get(self.block_refs) - - @property - def blocks(self) -> List[Block]: - if not self._resolved: - raise RuntimeError( - "The resolved blocks for this logical batch can only be " - "accessed after calling `resolve`." - ) - return self._blocks - - @dataclass class Batch: """A batch of data. @@ -63,12 +13,10 @@ class Batch: batch_idx: The global index of this batch so that downstream operations can maintain ordering. data: The batch of data. - logical_batch: The logical batch that was used to create this batch. """ batch_idx: int data: Any - logical_batch: LogicalBatch class BlockPrefetcher: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 94b027ca08878..f956a61daebc0 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -1,44 +1,37 @@ +import collections import heapq -import random -from typing import Any, Callable, Iterator, List, Optional, Tuple +from typing import Iterator, List, Optional, Tuple from ray.types import ObjectRef -from ray.data.block import Block, BlockMetadata, BlockAccessor, DataBatch +from ray.data.block import Block, BlockMetadata from ray.data._internal.block_batching.interfaces import ( Batch, - LogicalBatch, BlockPrefetcher, ) -from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder -from ray.data._internal.memory_tracing import trace_deallocation def bundle_block_refs_to_logical_batches( block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], batch_size: Optional[int], drop_last: bool = False, -) -> Iterator[LogicalBatch]: +) -> Iterator[List[ObjectRef[Block]]]: """Given an iterator of block object references, and their corresponding metadata, - bundles the block object references into groups of the provided `batch_size`. + bundles the block object references into groups of at least `batch_size`. + - The output iterator returns an iterator over LogicalBatch objects. This function does not do any slicing or creation of actual batch objects. """ batch_buffer: List[ObjectRef[Block]] = [] buffer_size = 0 - starting_index = 0 - - global_index = 0 - num_rows_in_last_block = 0 + original_batch_size = batch_size if batch_size is None: - for block_ref, metadata in block_ref_iterator: - yield LogicalBatch(global_index, [block_ref], 0, None, metadata.num_rows) - global_index += 1 + for block_ref, _ in block_ref_iterator: + yield [block_ref] else: while True: - if buffer_size < batch_size: + if buffer_size < batch_size or buffer_size <= 0: # Pull next block from iterator if current buffer is not enough to fill # a batch. try: @@ -47,246 +40,66 @@ def bundle_block_refs_to_logical_batches( break batch_buffer.append(block_ref) buffer_size += metadata.num_rows - num_rows_in_last_block = metadata.num_rows - if buffer_size == batch_size: - # If equal to batch size, then yield the full buffer. - yield LogicalBatch( - global_index, batch_buffer, starting_index, None, buffer_size - ) + else: + # If equal to or greater than batch size, then yield the full buffer. + yield batch_buffer + carryover_to_next_batch = buffer_size - batch_size batch_buffer = [] buffer_size = 0 - starting_index = 0 - num_rows_in_last_block = 0 - global_index += 1 - - if buffer_size > batch_size: - # If current buffer is greater than batch size, then yield part of the - # buffer, and carryover the remainder to the next batch. - num_rows_to_leave_behind = buffer_size - batch_size - ending_index = num_rows_in_last_block - num_rows_to_leave_behind - assert ending_index > 0, ending_index - yield LogicalBatch( - global_index, batch_buffer, starting_index, ending_index, batch_size - ) - global_index += 1 - # Carryover to next batch. - batch_buffer = [batch_buffer[-1]] - buffer_size = num_rows_to_leave_behind - starting_index = ending_index - - # Yield any leftover batches if necessary. - if buffer_size > 0 and not drop_last: - assert buffer_size < batch_size - yield LogicalBatch( - global_index, batch_buffer, starting_index, None, buffer_size - ) - global_index += 1 - - -def local_shuffle_logical_batches( - logical_batch_iterator: Iterator[LogicalBatch], - shuffle_buffer_min_size: int, - shuffle_seed: Optional[int] = None, -) -> Iterator[LogicalBatch]: - """Shuffles the logical batch iterator using a buffer of the provided size.""" - - if shuffle_seed is not None: - random.seed(shuffle_seed) - - shuffle_buffer: List[LogicalBatch] = [] - shuffle_buffer_size = 0 - global_counter = 0 - - for logical_batch in logical_batch_iterator: - shuffle_buffer.append(logical_batch) - shuffle_buffer_size += logical_batch.num_rows - - while shuffle_buffer_size >= shuffle_buffer_min_size: - output_batch = shuffle_buffer.pop( - random.randint(0, len(shuffle_buffer) - 1) - ) - output_batch.batch_idx = global_counter - yield output_batch - shuffle_buffer_size -= output_batch.num_rows - global_counter += 1 - - # Yield any leftover. - while len(shuffle_buffer) > 0: - output_batch = shuffle_buffer.pop(random.randint(0, len(shuffle_buffer) - 1)) - output_batch.batch_idx = global_counter - yield output_batch - global_counter += 1 + assert carryover_to_next_batch >= 0 + if carryover_to_next_batch == 0: + # Reset the + batch_size = original_batch_size + elif carryover_to_next_batch > 0: + # Carryover remainder to next batch so we don't prefetch too much. + # Example: 4 blocks with 2 rows each. Batch size of 3. + # Batch 1: Yield 2 blocks (4 total rows) + # Batch 2: Only yield 1 additional block since 1 row from the + # previous yield should be included in this batch. + batch_size = original_batch_size - carryover_to_next_batch + + # Yield any leftover batches if necessary. + assert buffer_size < original_batch_size + if buffer_size > 0 and not drop_last: + yield batch_buffer def prefetch_batches_locally( - logical_batch_iter: Iterator[LogicalBatch], + block_ref_iter: Iterator[List[ObjectRef[Block]]], prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, -) -> Iterator[LogicalBatch]: - """Given an iterator of logical batches, returns an iterator over the same logical - batches, while prefetching `num_batches_to_prefetch` batches in advance. +) -> Iterator[List[ObjectRef[Block]]]: + """Given an iterator of batched block references, returns an iterator over the same + block references while prefetching `num_batches_to_prefetch` batches in advance. Args: - logical_batch_iter: An iterator over logical batches. + block_ref_iter: An iterator over batched block references. prefetcher: The prefetcher to use. num_batches_to_prefetch: The number of batches to prefetch ahead of the current batch during the scan. """ - def get_next_batches() -> Iterator[List[LogicalBatch]]: - """Return lists of logical batches corresponding to `num_batches_to_prefetch`""" - next_batches = [] - while True: - try: - next_batches.append(next(logical_batch_iter)) - if len(next_batches) == num_batches_to_prefetch: - yield next_batches - next_batches = [] - except StopIteration: - break - - if len(next_batches) > 0: - yield next_batches - - # Fetch the initial set of batches. - batch_iterator = get_next_batches() - try: - batches = next(batch_iterator) - except StopIteration: - return - - block_refs = [block_ref for batch in batches for block_ref in batch.block_refs] - prefetcher.prefetch_blocks(block_refs) - - for next_batches in batch_iterator: - # Prefetch the next batches. - block_refs = [ - block_ref for batch in next_batches for block_ref in batch.block_refs - ] - prefetcher.prefetch_blocks(block_refs) - - for batch in batches: - yield batch - - batches = next_batches - - # Yield the final set of batches. - for batch in batches: - yield batch - - -def resolve_logical_batch( - logical_batch_iter: Iterator[LogicalBatch], -) -> Iterator[LogicalBatch]: - """Resolves the block references for each logical batch.""" - for logical_batch in logical_batch_iter: - logical_batch.resolve() - yield logical_batch - - -def construct_batch_from_logical_batch( - resolved_logical_batch_iter: Iterator[LogicalBatch], - ensure_copy: bool = False, -) -> Iterator[Batch]: - """Given an iterator over logical batches, returns an iterator over actual - constructed batches. - - Args: - resolved_logical_batch_iter: An iterator over resolved logical batches. - stats: Dataset stats object used to store block batching time. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). - - Returns: - An iterator over batch index and batches of the given size. - """ - - for logical_batch in resolved_logical_batch_iter: - output = DelegatingBlockBuilder() - slice_indices = [[0, None] for _ in range(len(logical_batch.blocks))] - if logical_batch.starting_block_idx > 0: - slice_indices[0][0] = logical_batch.starting_block_idx - if logical_batch.ending_block_idx is not None: - slice_indices[-1][1] = logical_batch.ending_block_idx - - for i, block in enumerate(logical_batch.blocks): - accessor = BlockAccessor.for_block(block) - slice_index = slice_indices[i] - output.add_block( - accessor.slice( - slice_index[0], - slice_index[1] - if slice_index[1] is not None - else accessor.num_rows(), - copy=False, - ) + sliding_window = collections.deque(maxlen=num_batches_to_prefetch) + # Create and fetch the initial window. + for _ in range(num_batches_to_prefetch): + try: + sliding_window.append(next(block_ref_iter)) + except StopIteration: + break + prefetcher.prefetch_blocks( + [block_ref for batch in list(sliding_window) for block_ref in batch] + ) + + while sliding_window: + batch = sliding_window.popleft() + try: + sliding_window.append(next(block_ref_iter)) + prefetcher.prefetch_blocks( + [block_ref for batch in list(sliding_window) for block_ref in batch] ) - - batch = output.build() - assert len(batch) == logical_batch.num_rows, ( - len(batch), - logical_batch.num_rows, - ) - if ensure_copy: - # Need to ensure that the batch is a fresh copy. - batch = BlockAccessor.for_block(batch) - batch = batch.slice(0, batch.num_rows(), copy=True) - - yield Batch(logical_batch.batch_idx, batch, logical_batch) - - -def format_batches( - block_iter: Iterator[Batch], - batch_format: Optional[str], -) -> Iterator[Batch]: - """Given an iterator of blocks, returns an iterator of formatted batches. - - Args: - block_iter: An iterator over blocks. - batch_format: The batch format to use. - stats: An optional stats object to record formatting times. - - Returns: - An iterator over batch index and the formatted batch. - """ - for batch in block_iter: - formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( - batch_format - ) - batch.data = formatted_batch - yield batch - - -def collate( - batch_iter: Iterator[Batch], - collate_fn: Optional[Callable[[DataBatch], Any]], -) -> Iterator[Batch]: - """Returns an iterator with the provided collate_fn applied to items of the batch - iterator. - - Args: - batch_iter: An iterator over formatted batches. - stats: An optional stats object to record collation time. - """ - for batch in batch_iter: - batch.data = collate_fn(batch.data) - yield batch - - -def trace_deallocation_for_batch( - batch_iter: Iterator[Batch], eager_free: bool -) -> Iterator[Batch]: - """Trace deallocation of the underlying block references for each batch. - - Args: - batch_iter: An iterator over batches. - eager_free: Whether to eagerly free the object reference from the object store. - """ - for batch in batch_iter: - block_refs = batch.logical_batch.block_refs - for block_ref in block_refs: - trace_deallocation(block_ref, loc="iter_batches", free=eager_free) + except StopIteration: + pass yield batch diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index ab6c4fe5e0dc1..0b5b24a2d7367 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -1,13 +1,17 @@ import logging import queue import threading -from typing import Callable, Iterator, TypeVar +import sys +from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union import ray from ray.actor import ActorHandle from ray.types import ObjectRef -from ray.data.block import Block -from ray.data._internal.block_batching.interfaces import BlockPrefetcher +from ray.data.block import Block, BlockAccessor, DataBatch +from ray.data._internal.batcher import Batcher, ShufflingBatcher +from ray.data._internal.block_batching.interfaces import Batch, BlockPrefetcher +from ray.data._internal.memory_tracing import trace_deallocation +from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy T = TypeVar("T") @@ -15,8 +19,181 @@ logger = logging.getLogger(__name__) +if sys.version_info >= (3, 7): + from contextlib import nullcontext +else: + from contextlib import contextmanager -def _make_async_gen( + @contextmanager + def nullcontext(enter_result=None): + yield enter_result + + +def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: + """Given a list of object references, returns how many are already on the local + node, how many require fetching from another node, and how many have unknown + locations.""" + current_node_id = ray.get_runtime_context().get_node_id() + + locs = ray.experimental.get_object_locations(refs) + nodes: List[List[str]] = [loc["node_ids"] for loc in locs.values()] + hits = sum(current_node_id in node_ids for node_ids in nodes) + unknowns = sum(1 for node_ids in nodes if not node_ids) + misses = len(nodes) - hits - unknowns + return hits, misses, unknowns + + +def resolve_block_refs( + block_ref_iter: Iterator[List[ObjectRef[Block]]], + eager_free: bool = False, + stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, +) -> Iterator[Block]: + """Resolves the block references for each logical batch. + + Args: + block_ref_iter: An iterator over block object references. + eager_free: Whether to eagerly free the object reference from the object store. + stats: An optional stats object to recording block hits and misses. + """ + hits = 0 + misses = 0 + unknowns = 0 + + for block_refs in block_ref_iter: + current_hit, current_miss, current_unknown = _calculate_ref_hits(block_refs) + hits += current_hit + misses += current_miss + unknowns += current_unknown + + with stats.iter_get_s.thread_timer() if stats else nullcontext(): + blocks = ray.get(block_refs) + for block_ref in block_refs: + trace_deallocation(block_ref, loc="iter_batches", free=eager_free) + for block in blocks: + yield block + + if stats: + stats.iter_blocks_local += hits + stats.iter_blocks_remote += misses + stats.iter_unknown_location += unknowns + + +def blocks_to_batches( + block_iter: Iterator[Block], + stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, + batch_size: Optional[int] = None, + drop_last: bool = False, + shuffle_buffer_min_size: Optional[int] = None, + shuffle_seed: Optional[int] = None, + ensure_copy: bool = False, +) -> Iterator[Batch]: + """Given an iterator over blocks, returns an iterator over blocks + of the appropriate bacth size. + + If the shuffling configurations are specified, then the + output blocks contain shuffled data. + + Args: + block_iter: An iterator over blocks. + stats: Dataset stats object used to store block batching time. + batch_size: Record batch size, or None to let the system pick. + drop_last: Whether to drop the last batch if it's incomplete. + shuffle_buffer_min_size: If non-None, the data will be randomly shuffled + using a local in-memory shuffle buffer, and this value will serve as the + minimum number of rows that must be in the local in-memory shuffle buffer in + order to yield a batch. + shuffle_seed: The seed to use for the local random shuffle. + ensure_copy: Whether batches are always copied from the underlying base + blocks (not zero-copy views). + + Returns: + An iterator over blocks of the given size that are potentially shuffled. + """ + if shuffle_buffer_min_size is not None: + batcher = ShufflingBatcher( + batch_size=batch_size, + shuffle_buffer_min_size=shuffle_buffer_min_size, + shuffle_seed=shuffle_seed, + ) + else: + batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) + + def get_iter_next_batch_s_timer(): + return stats.iter_next_batch_s.timer() if stats else nullcontext() + + global_counter = 0 + + for block in block_iter: + batcher.add(block) + while batcher.has_batch(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield Batch(global_counter, batch) + global_counter += 1 + + # Signal to the batcher that there are no more blocks to add. + batcher.done_adding() + + # Get any leftover batches in ShufflingBatcher. + while batcher.has_batch(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield Batch(global_counter, batch) + global_counter += 1 + + # Get any remaining data. + if not drop_last and batcher.has_any(): + with get_iter_next_batch_s_timer(): + batch = batcher.next_batch() + yield Batch(global_counter, batch) + global_counter += 1 + + +def format_batches( + block_iter: Iterator[Batch], + batch_format: Optional[str], + stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, +) -> Iterator[Batch]: + """Given an iterator of blocks, returns an iterator of formatted batches. + + Args: + block_iter: An iterator over blocks. + batch_format: The batch format to use. + stats: An optional stats object to record formatting times. + + Returns: + An iterator over batch index and the formatted batch. + """ + for batch in block_iter: + with stats.iter_format_batch_s.timer() if stats else nullcontext(): + formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) + batch.data = formatted_batch + yield batch + + +def collate( + batch_iter: Iterator[Batch], + collate_fn: Optional[Callable[[DataBatch], Any]], +) -> Iterator[Batch]: + """Returns an iterator with the provided collate_fn applied to items of the batch + iterator. + + Args: + batch_iter: An iterator over formatted batches. + """ + for batch in batch_iter: + batch.data = collate_fn(batch.data) + yield batch + + +def extract_data_from_batch(batch_iter: Iterator[Batch]) -> Iterator[Any]: + for batch in batch_iter: + yield batch.data + + +def make_async_gen( base_iterator: Iterator[T], fn: Callable[[Iterator[T]], Iterator[U]], num_workers: int = 1, diff --git a/python/ray/data/tests/block_batching/test_block_batching.py b/python/ray/data/tests/block_batching/test_block_batching.py index 357eff91cc42f..7351ac165af2a 100644 --- a/python/ray/data/tests/block_batching/test_block_batching.py +++ b/python/ray/data/tests/block_batching/test_block_batching.py @@ -3,8 +3,6 @@ from typing import List from unittest import mock -import numpy as np -import pandas as pd import pyarrow as pa from ray.data.block import Block @@ -13,8 +11,6 @@ batch_block_refs, batch_blocks, _prefetch_blocks, - _blocks_to_batches, - _format_batches, ) @@ -39,9 +35,9 @@ def test_batch_block_refs(): def test_batch_blocks(): with mock.patch( - "ray.data._internal.block_batching.block_batching._blocks_to_batches" + "ray.data._internal.block_batching.block_batching.blocks_to_batches" ) as mock_batch, mock.patch( - "ray.data._internal.block_batching.block_batching._format_batches" + "ray.data._internal.block_batching.block_batching.format_batches" ) as mock_format: block_iter = block_generator(2, 2) batch_iter = batch_blocks(block_iter) @@ -78,50 +74,6 @@ def prefetch_blocks(self, blocks: List[Block]): assert all(len(window) == num_blocks_to_prefetch for window in windows) -@pytest.mark.parametrize("block_size", [1, 10]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_blocks_to_batches(block_size, drop_last): - num_blocks = 5 - block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) - - batch_size = 3 - batch_iter = _blocks_to_batches( - block_iter, batch_size=batch_size, drop_last=drop_last - ) - - if drop_last: - for batch in batch_iter: - assert len(batch) == batch_size - else: - full_batches = 0 - leftover_batches = 0 - - dataset_size = block_size * num_blocks - for batch in batch_iter: - if len(batch) == batch_size: - full_batches += 1 - if len(batch) == (dataset_size % batch_size): - leftover_batches += 1 - - assert leftover_batches == 1 - assert full_batches == (dataset_size // batch_size) - - -@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) -def test_format_batches(batch_format): - block_iter = block_generator(num_rows=2, num_blocks=2) - batch_iter = _format_batches(block_iter, batch_format=batch_format) - - for batch in batch_iter: - if batch_format == "pandas": - assert isinstance(batch, pd.DataFrame) - elif batch_format == "arrow": - assert isinstance(batch, pa.Table) - elif batch_format == "numpy": - assert isinstance(batch, dict) - assert isinstance(batch["foo"], np.ndarray) - - # Test for 3 cases # 1. Batch size is less than block size # 2. Batch size is more than block size @@ -136,7 +88,7 @@ def sleep_batch_format(batch_iter, *args, **kwargs): yield batch with mock.patch( - "ray.data._internal.block_batching.block_batching._format_batches", + "ray.data._internal.block_batching.util.format_batches", sleep_batch_format, ): batch_iter = batch_blocks( diff --git a/python/ray/data/tests/block_batching/test_interfaces.py b/python/ray/data/tests/block_batching/test_interfaces.py deleted file mode 100644 index 0f061596a4ea7..0000000000000 --- a/python/ray/data/tests/block_batching/test_interfaces.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -import ray -from ray.data._internal.block_batching.interfaces import LogicalBatch - - -def test_logical_batch_resolves_blocks(ray_start_regular_shared): - block_refs = [ray.put(1), ray.put(2)] - logical_batch = LogicalBatch( - batch_idx=0, - block_refs=block_refs, - starting_block_idx=0, - ending_block_idx=None, - num_rows=2, - ) - - # Blocks should not be accessible before calling resolve(). - with pytest.raises(RuntimeError): - logical_batch.blocks - - logical_batch.resolve() - assert logical_batch.blocks == [1, 2] - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 21b0282430868..062da6d392025 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,38 +1,25 @@ -from copy import copy import pytest from typing import Iterator, List, Tuple -from unittest.mock import patch -import numpy as np -import pandas as pd import pyarrow as pa -import ray -from ray.types import ObjectRef from ray.data.block import Block, BlockMetadata from ray.data._internal.block_batching.interfaces import ( Batch, - LogicalBatch, BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( bundle_block_refs_to_logical_batches, - local_shuffle_logical_batches, prefetch_batches_locally, - resolve_logical_batch, - construct_batch_from_logical_batch, - format_batches, - collate, - trace_deallocation_for_batch, restore_from_original_order, ) def block_generator( num_rows: int, num_blocks: int -) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]: +) -> Iterator[Tuple[Block, BlockMetadata]]: for i in range(num_blocks): - yield ray.put(pa.table({"foo": [i] * num_rows})), BlockMetadata( + yield pa.table({"foo": [i] * num_rows}), BlockMetadata( num_rows=num_rows, size_bytes=0, schema=None, @@ -41,25 +28,7 @@ def block_generator( ) -def logical_batch_generator( - num_rows: int, num_blocks: int, batch_size: int = None -) -> Iterator[LogicalBatch]: - logical_batch_iter = bundle_block_refs_to_logical_batches( - block_generator(num_rows=num_rows, num_blocks=num_blocks), batch_size=batch_size - ) - return logical_batch_iter - - -def resolved_logical_batch_generator( - num_rows: int, num_blocks: int, batch_size: int = None -): - logical_batch_iter = logical_batch_generator(num_rows, num_blocks, batch_size) - for logical_batch in logical_batch_iter: - logical_batch.resolve() - yield logical_batch - - -def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): +def test_bundle_block_refs_to_logical_batches(): # Case 1: `batch_size` is None. num_blocks = 4 num_rows_per_block = 2 @@ -71,13 +40,14 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): ) logical_batches = list(logical_batch_iter) assert logical_batches == [ - LogicalBatch(0, [block_refs[0][0]], 0, None, num_rows_per_block), - LogicalBatch(1, [block_refs[1][0]], 0, None, num_rows_per_block), - LogicalBatch(2, [block_refs[2][0]], 0, None, num_rows_per_block), - LogicalBatch(3, [block_refs[3][0]], 0, None, num_rows_per_block), + [block_refs[0][0]], + [block_refs[1][0]], + [block_refs[2][0]], + [block_refs[3][0]], ] # Case 2: Multiple batches in a block (`batch_size` is 1). + # There should be no overlap. num_blocks = 2 num_rows_per_block = 2 batch_size = 1 @@ -87,12 +57,7 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): iter(block_refs), batch_size=batch_size ) logical_batches = list(logical_batch_iter) - assert logical_batches == [ - LogicalBatch(0, [block_refs[0][0]], 0, 1, batch_size), - LogicalBatch(1, [block_refs[0][0]], 1, None, batch_size), - LogicalBatch(2, [block_refs[1][0]], 0, 1, batch_size), - LogicalBatch(3, [block_refs[1][0]], 1, None, batch_size), - ] + assert logical_batches == [[block_refs[0][0]], [block_refs[1][0]]] # Case 3: Multiple blocks in a batch (`batch_size` is 2) num_blocks = 4 @@ -105,8 +70,8 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): ) logical_batches = list(logical_batch_iter) assert logical_batches == [ - LogicalBatch(0, [block_refs[0][0], block_refs[1][0]], 0, None, batch_size), - LogicalBatch(1, [block_refs[2][0], block_refs[3][0]], 0, None, batch_size), + [block_refs[0][0], block_refs[1][0]], + [block_refs[2][0], block_refs[3][0]], ] # Case 4: Batches overlap across multiple blocks unevenly @@ -120,9 +85,9 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): ) logical_batches = list(logical_batch_iter) assert logical_batches == [ - LogicalBatch(0, [block_refs[0][0], block_refs[1][0]], 0, 1, batch_size), - LogicalBatch(1, [block_refs[1][0], block_refs[2][0]], 1, None, batch_size), - LogicalBatch(2, [block_refs[3][0]], 0, None, 2), # Leftover block. + [block_refs[0][0], block_refs[1][0]], + [block_refs[2][0]], + [block_refs[3][0]], # Leftover block. ] # Case 5: Batches overlap across multiple blocks unevenly, dropping the last @@ -137,53 +102,13 @@ def test_bundle_block_refs_to_logical_batches(ray_start_regular_shared): ) logical_batches = list(logical_batch_iter) assert logical_batches == [ - LogicalBatch(0, [block_refs[0][0], block_refs[1][0]], 0, 1, batch_size), - LogicalBatch(1, [block_refs[1][0], block_refs[2][0]], 1, None, batch_size), + [block_refs[0][0], block_refs[1][0]], + [block_refs[2][0]], ] -def test_local_shuffle_logical_batches(ray_start_regular_shared): - # Case 1: Shuffle buffer min size is smaller than a batch. - # In this case, there is effectively no shuffling since the buffer - # never contains more than 1 batch. - shuffle_seed = 42 - num_blocks = 4 - num_rows_per_block = 2 - shuffle_buffer_min_size = 1 - logical_batches = list(logical_batch_generator(num_rows_per_block, num_blocks)) - shuffled_batches = list( - local_shuffle_logical_batches( - iter(logical_batches), - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - ) - assert shuffled_batches == logical_batches - - # Case 2: Shuffle buffer min size is greater than a batch. - shuffle_seed = 42 - num_blocks = 4 - num_rows_per_block = 1 - shuffle_buffer_min_size = 2 - logical_batches = list(logical_batch_generator(num_rows_per_block, num_blocks)) - shuffled_batches = list( - local_shuffle_logical_batches( - iter(logical_batches), - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - ) - - expected_output_ordering = [0, 1, 3, 2] - expected_output = [copy(logical_batches[i]) for i in expected_output_ordering] - for i in range(len(expected_output)): - expected_output[i].batch_idx = i - - assert shuffled_batches == expected_output - - @pytest.mark.parametrize("num_batches_to_prefetch", [1, 2]) -def test_prefetch_batches_locally(ray_start_regular_shared, num_batches_to_prefetch): +def test_prefetch_batches_locally(num_batches_to_prefetch): class DummyPrefetcher(BlockPrefetcher): def __init__(self): self.windows = [] @@ -191,106 +116,32 @@ def __init__(self): def prefetch_blocks(self, blocks: List[Block]): self.windows.append(blocks) - num_blocks = 10 + num_batches = 10 prefetcher = DummyPrefetcher() - logical_batches = list(logical_batch_generator(1, num_blocks)) - prefetch_batch_iter = prefetch_batches_locally( - iter(logical_batches), + block_iter = iter([[i] for i in range(num_batches)]) + prefetch_block_iter = prefetch_batches_locally( + block_iter, prefetcher=prefetcher, num_batches_to_prefetch=num_batches_to_prefetch, ) - # Test that we are actually prefetching. - # We should prefetch a new set of batches after the current set - # finishes. - sets_prefetched = 1 - output_batches = [] - for i, batch in enumerate(prefetch_batch_iter): - if i % num_batches_to_prefetch == 0: - # If all the batches are already prefetched, then skip the check. - if not sets_prefetched * num_batches_to_prefetch >= len(logical_batches): - assert len(prefetcher.windows) == sets_prefetched + 1 - sets_prefetched = len(prefetcher.windows) - output_batches.append(batch) + batch_count = 1 + for _ in prefetch_block_iter: + batch_count += 1 + if batch_count < num_batches: + # Test that we are actually prefetching. + assert len(prefetcher.windows) == batch_count windows = prefetcher.windows assert all(len(window) == num_batches_to_prefetch for window in windows) - # Check that the output iterator is the same as the input iterator. - assert output_batches == logical_batches - - -def test_resolve_logical_batches(ray_start_regular_shared): - logical_batches = list(logical_batch_generator(1, 1)) - resolved_iter = resolve_logical_batch(iter(logical_batches)) - assert next(resolved_iter).blocks == ray.get(logical_batches[0].block_refs) - - -@pytest.mark.parametrize("block_size", [1, 10]) -def test_construct_batch_from_logical_batch(ray_start_regular_shared, block_size): - num_blocks = 5 - batch_size = 3 - logical_batches = list( - resolved_logical_batch_generator(block_size, num_blocks, batch_size=batch_size) - ) - - created_batches = list(construct_batch_from_logical_batch(iter(logical_batches))) - - for i, batch in enumerate(created_batches): - assert i == batch.batch_idx - assert len(batch.data) == logical_batches[i].num_rows - - -@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) -def test_format_batches(ray_start_regular_shared, batch_format): - batches = [ - Batch(i, ray.get(data[0]), None) - for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) - ] - batch_iter = format_batches(batches, batch_format=batch_format) - - for i, batch in enumerate(batch_iter): - assert batch.batch_idx == i - if batch_format == "pandas": - assert isinstance(batch.data, pd.DataFrame) - elif batch_format == "arrow": - assert isinstance(batch.data, pa.Table) - elif batch_format == "numpy": - assert isinstance(batch.data, dict) - assert isinstance(batch.data["foo"], np.ndarray) - - -def test_collate(ray_start_regular_shared): - def collate_fn(batch): - return pa.table({"bar": [1] * 2}) - - batches = [ - Batch(i, ray.get(data[0]), None) - for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) - ] - batch_iter = collate(batches, collate_fn=collate_fn) - - for i, batch in enumerate(batch_iter): - assert batch.batch_idx == i - assert batch.data == pa.table({"bar": [1] * 2}) - - -@patch.object(ray.data._internal.block_batching.iter_batches, "trace_deallocation") -@pytest.mark.parametrize("eager_free", [True, False]) -def test_trace_deallocation(mock, eager_free): - batches = [Batch(0, 0, LogicalBatch(0, [0], 0, None, 1))] - batch_iter = trace_deallocation_for_batch(iter(batches), eager_free=eager_free) - # Test that the underlying batch is not modified. - assert next(batch_iter) == batches[0] - mock.assert_called_once_with(0, loc="iter_batches", free=eager_free) - def test_restore_from_original_order(): base_iterator = [ - Batch(1, None, None), - Batch(0, None, None), - Batch(3, None, None), - Batch(2, None, None), + Batch(1, None), + Batch(0, None), + Batch(3, None), + Batch(2, None), ] ordered = list(restore_from_original_order(iter(base_iterator))) diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 47140686de01b..551df1cae632d 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -1,7 +1,95 @@ import pytest import time -from ray.data._internal.block_batching.util import _make_async_gen +import numpy as np +import pandas as pd +import pyarrow as pa + +import ray +from ray.data._internal.block_batching.util import ( + make_async_gen, + blocks_to_batches, + format_batches, + collate, + resolve_block_refs, +) +from ray.data._internal.block_batching.interfaces import Batch + + +def block_generator(num_rows: int, num_blocks: int): + for _ in range(num_blocks): + yield pa.table({"foo": [1] * num_rows}) + + +def test_resolve_block_refs(ray_start_regular_shared): + block_refs = [[ray.put(0), ray.put(1)], [ray.put(2)]] + + resolved_iter = resolve_block_refs(iter(block_refs)) + assert list(resolved_iter) == [0, 1, 2] + + +@pytest.mark.parametrize("block_size", [1, 10]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_blocks_to_batches(block_size, drop_last): + num_blocks = 5 + block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) + + batch_size = 3 + batch_iter = list( + blocks_to_batches(block_iter, batch_size=batch_size, drop_last=drop_last) + ) + + if drop_last: + for batch in batch_iter: + assert len(batch.data) == batch_size + else: + full_batches = 0 + leftover_batches = 0 + + dataset_size = block_size * num_blocks + for batch in batch_iter: + if len(batch.data) == batch_size: + full_batches += 1 + if len(batch.data) == (dataset_size % batch_size): + leftover_batches += 1 + + assert leftover_batches == 1 + assert full_batches == (dataset_size // batch_size) + + assert [batch.batch_idx for batch in batch_iter] == list(range(len(batch_iter))) + + +@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) +def test_format_batches(batch_format): + block_iter = block_generator(num_rows=2, num_blocks=2) + batch_iter = (Batch(i, block) for i, block in enumerate(block_iter)) + batch_iter = list(format_batches(batch_iter, batch_format=batch_format)) + + for batch in batch_iter: + if batch_format == "pandas": + assert isinstance(batch.data, pd.DataFrame) + elif batch_format == "arrow": + assert isinstance(batch.data, pa.Table) + elif batch_format == "numpy": + assert isinstance(batch.data, dict) + assert isinstance(batch.data["foo"], np.ndarray) + + assert [batch.batch_idx for batch in batch_iter] == list(range(len(batch_iter))) + + +def test_collate(): + def collate_fn(batch): + return pa.table({"bar": [1] * 2}) + + batches = [ + Batch(i, data) + for i, data in enumerate(block_generator(num_rows=2, num_blocks=2)) + ] + batch_iter = collate(batches, collate_fn=collate_fn) + + for i, batch in enumerate(batch_iter): + assert batch.batch_idx == i + assert batch.data == pa.table({"bar": [1] * 2}) def test_make_async_gen_fail(): @@ -11,7 +99,7 @@ def test_make_async_gen_fail(): def gen(base_iterator): raise ValueError("Fail") - iterator = _make_async_gen(base_iterator=iter([1]), fn=gen) + iterator = make_async_gen(base_iterator=iter([1]), fn=gen) with pytest.raises(ValueError) as e: for _ in iterator: @@ -34,7 +122,7 @@ def sleep_udf(item): time.sleep(3) return item - iterator = _make_async_gen( + iterator = make_async_gen( base_iterator=iter(range(num_items)), fn=gen, num_workers=1 ) @@ -66,7 +154,7 @@ def sleep_udf(item): return item # All 5 items should be fetched concurrently. - iterator = _make_async_gen( + iterator = make_async_gen( base_iterator=iter(range(num_items)), fn=gen, num_workers=5 ) From ccc00c2c8f6537708ca0ff4424ab80b0f134b129 Mon Sep 17 00:00:00 2001 From: amogkam Date: Thu, 23 Mar 2023 23:14:49 -0700 Subject: [PATCH 31/75] update Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index f956a61daebc0..4e12c0d9aaf0a 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -1,6 +1,5 @@ import collections -import heapq -from typing import Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple from ray.types import ObjectRef from ray.data.block import Block, BlockMetadata @@ -113,12 +112,14 @@ def restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: (base_iterator) must be present. """ next_index_required = 0 - buffer: List[Batch] = [] + buffer: Dict[int, Batch] = {} for batch in batch_iter: - heapq.heappush(buffer, (batch.batch_idx, batch)) - if buffer[0][0] == next_index_required: - yield heapq.heappop(buffer)[1] + assert batch.batch_idx not in buffer + buffer[batch.batch_idx] = batch + while next_index_required in buffer: + yield buffer.pop(next_index_required) next_index_required += 1 - while len(buffer) > 0: - yield heapq.heappop(buffer)[1] + while next_index_required in buffer: + yield buffer.pop(next_index_required) + next_index_required += 1 From 64638a5a5898997e165a5341a23e22ce9940c91f Mon Sep 17 00:00:00 2001 From: amogkam Date: Thu, 23 Mar 2023 23:17:18 -0700 Subject: [PATCH 32/75] update Signed-off-by: amogkam --- .../ray/data/_internal/block_batching/iter_batches.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 4e12c0d9aaf0a..03de07f134b3e 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -44,19 +44,17 @@ def bundle_block_refs_to_logical_batches( # If equal to or greater than batch size, then yield the full buffer. yield batch_buffer carryover_to_next_batch = buffer_size - batch_size + # Reset the batch size. + batch_size = original_batch_size batch_buffer = [] buffer_size = 0 - assert carryover_to_next_batch >= 0 - if carryover_to_next_batch == 0: - # Reset the - batch_size = original_batch_size - elif carryover_to_next_batch > 0: + if carryover_to_next_batch > 0: # Carryover remainder to next batch so we don't prefetch too much. # Example: 4 blocks with 2 rows each. Batch size of 3. # Batch 1: Yield 2 blocks (4 total rows) # Batch 2: Only yield 1 additional block since 1 row from the # previous yield should be included in this batch. - batch_size = original_batch_size - carryover_to_next_batch + batch_size = batch_size - carryover_to_next_batch # Yield any leftover batches if necessary. assert buffer_size < original_batch_size From e6cdb0647e595be1360fda6322ee21d7d87ccdc9 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 13:45:24 -0700 Subject: [PATCH 33/75] address comments Signed-off-by: amogkam --- .../_internal/block_batching/interfaces.py | 24 +++- .../_internal/block_batching/iter_batches.py | 117 ++++++++-------- .../ray/data/_internal/block_batching/util.py | 15 +- .../tests/block_batching/test_iter_batches.py | 128 +++++------------- 4 files changed, 115 insertions(+), 169 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index e2672863d9ace..f474b84c7971b 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,13 +1,11 @@ -from dataclasses import dataclass -from typing import Any +from typing import Any, NamedTuple from ray.types import ObjectRef -from ray.data.block import Block +from ray.data.block import Block, DataBatch -@dataclass -class Batch: - """A batch of data. +class Batch(NamedTuple): + """A batch of data with a corresponding index. Attributes: batch_idx: The global index of this batch so that downstream operations can @@ -15,6 +13,20 @@ class Batch: data: The batch of data. """ + batch_idx: int + data: DataBatch + + +class CollatedBatch(NamedTuple): + """A batch of collated data with a corresponding index. + + Attributes: + batch_idx: The global index of this batch so that downstream operations can + maintain ordering. + data: The batch of data which is the output of a user provided collate_fn + Therefore, the type of this data can be Any. + """ + batch_idx: int data: Any diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 03de07f134b3e..91681aa02de27 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -8,64 +8,40 @@ BlockPrefetcher, ) +""" +The algorithm uses both pipeline parallelism and data parallelism: -def bundle_block_refs_to_logical_batches( - block_ref_iterator: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], - batch_size: Optional[int], - drop_last: bool = False, -) -> Iterator[List[ObjectRef[Block]]]: - """Given an iterator of block object references, and their corresponding metadata, - bundles the block object references into groups of at least `batch_size`. +If prefetch_batches=2, these are all the batches in flight: +[User thread] trains on Batch 0 + - [Fetch thread] Batch 1 in output queue + - [Worker thread 1] Batch 2 formatting + collating + - [Worker thread 2] Batch 3 formatting + collating + - [Raylet] Batches 4 + 5 fetched to local object store memory +At any point in time there are prefetch_batches+1 batches in local heap memory +And the next set of prefetch_batches in local object store memory - This function does not do any slicing or creation of actual batch objects. - """ - batch_buffer: List[ObjectRef[Block]] = [] - buffer_size = 0 - original_batch_size = batch_size - - if batch_size is None: - for block_ref, _ in block_ref_iterator: - yield [block_ref] - else: - while True: - if buffer_size < batch_size or buffer_size <= 0: - # Pull next block from iterator if current buffer is not enough to fill - # a batch. - try: - block_ref, metadata = next(block_ref_iterator) - except StopIteration: - break - batch_buffer.append(block_ref) - buffer_size += metadata.num_rows - - else: - # If equal to or greater than batch size, then yield the full buffer. - yield batch_buffer - carryover_to_next_batch = buffer_size - batch_size - # Reset the batch size. - batch_size = original_batch_size - batch_buffer = [] - buffer_size = 0 - if carryover_to_next_batch > 0: - # Carryover remainder to next batch so we don't prefetch too much. - # Example: 4 blocks with 2 rows each. Batch size of 3. - # Batch 1: Yield 2 blocks (4 total rows) - # Batch 2: Only yield 1 additional block since 1 row from the - # previous yield should be included in this batch. - batch_size = batch_size - carryover_to_next_batch - - # Yield any leftover batches if necessary. - assert buffer_size < original_batch_size - if buffer_size > 0 and not drop_last: - yield batch_buffer +The actual steps are as follows: + +In a single async thread, do the following: + 1. Trigger Ray local prefetching of `prefetch_batches` worth of block object + references. + 2. Resolve (i.e. call `ray.get()`) on the block references + 3. Perform the necessary batch slicing to construct full batches, possibly + shuffling if necessary. + 4. Then, in a threadpool consisting of `prefetch_batches` threads: + 3. Format the batches to the provided batch format. + 4. Apply the collate function + 5. Fetch outputs from the threadpool, maintaining order of the batches. +""" def prefetch_batches_locally( - block_ref_iter: Iterator[List[ObjectRef[Block]]], + block_ref_iter: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, + batch_size: Optional[int], ) -> Iterator[List[ObjectRef[Block]]]: """Given an iterator of batched block references, returns an iterator over the same block references while prefetching `num_batches_to_prefetch` batches in advance. @@ -75,29 +51,42 @@ def prefetch_batches_locally( prefetcher: The prefetcher to use. num_batches_to_prefetch: The number of batches to prefetch ahead of the current batch during the scan. + batch_size: User specified batch size, or None to let the system pick. """ - sliding_window = collections.deque(maxlen=num_batches_to_prefetch) + sliding_window = collections.deque() + current_window_size = 0 + + if batch_size: + num_rows_to_prefetch = num_batches_to_prefetch * batch_size + # Create and fetch the initial window. - for _ in range(num_batches_to_prefetch): + while True: try: - sliding_window.append(next(block_ref_iter)) + next_block_ref_and_metadata = next(block_ref_iter) + sliding_window.append(next_block_ref_and_metadata) + current_window_size += next_block_ref_and_metadata[1].num_rows except StopIteration: break - prefetcher.prefetch_blocks( - [block_ref for batch in list(sliding_window) for block_ref in batch] - ) + if batch_size and current_window_size >= num_rows_to_prefetch: + break + elif not batch_size and len(sliding_window) >= num_batches_to_prefetch: + break + + prefetcher.prefetch_blocks([block_ref for block_ref, _ in list(sliding_window)]) while sliding_window: - batch = sliding_window.popleft() - try: - sliding_window.append(next(block_ref_iter)) - prefetcher.prefetch_blocks( - [block_ref for batch in list(sliding_window) for block_ref in batch] - ) - except StopIteration: - pass - yield batch + block_ref, metadata = sliding_window.popleft() + current_window_size -= metadata.num_rows + if not batch_size or current_window_size < num_rows_to_prefetch: + try: + sliding_window.append(next(block_ref_iter)) + prefetcher.prefetch_blocks( + [block_ref for block_ref, _ in list(sliding_window)] + ) + except StopIteration: + pass + yield block_ref def restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 0b5b24a2d7367..9488d2f7ad49f 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -9,7 +9,11 @@ from ray.types import ObjectRef from ray.data.block import Block, BlockAccessor, DataBatch from ray.data._internal.batcher import Batcher, ShufflingBatcher -from ray.data._internal.block_batching.interfaces import Batch, BlockPrefetcher +from ray.data._internal.block_batching.interfaces import ( + Batch, + CollatedBatch, + BlockPrefetcher, +) from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -169,14 +173,13 @@ def format_batches( formatted_batch = BlockAccessor.for_block(batch.data).to_batch_format( batch_format ) - batch.data = formatted_batch - yield batch + yield Batch(batch.batch_idx, formatted_batch) def collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], -) -> Iterator[Batch]: +) -> Iterator[CollatedBatch]: """Returns an iterator with the provided collate_fn applied to items of the batch iterator. @@ -184,8 +187,8 @@ def collate( batch_iter: An iterator over formatted batches. """ for batch in batch_iter: - batch.data = collate_fn(batch.data) - yield batch + collated_batch = collate_fn(batch.data) + yield CollatedBatch(batch.batch_idx, collated_batch) def extract_data_from_batch(batch_iter: Iterator[Batch]) -> Iterator[Any]: diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 062da6d392025..6a4afbdffcb92 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -9,7 +9,6 @@ BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( - bundle_block_refs_to_logical_batches, prefetch_batches_locally, restore_from_original_order, ) @@ -28,112 +27,55 @@ def block_generator( ) -def test_bundle_block_refs_to_logical_batches(): - # Case 1: `batch_size` is None. - num_blocks = 4 - num_rows_per_block = 2 - batch_size = None - block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) - block_refs = list(block_iter) - logical_batch_iter = bundle_block_refs_to_logical_batches( - iter(block_refs), batch_size=batch_size - ) - logical_batches = list(logical_batch_iter) - assert logical_batches == [ - [block_refs[0][0]], - [block_refs[1][0]], - [block_refs[2][0]], - [block_refs[3][0]], - ] - - # Case 2: Multiple batches in a block (`batch_size` is 1). - # There should be no overlap. - num_blocks = 2 - num_rows_per_block = 2 - batch_size = 1 - block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) - block_refs = list(block_iter) - logical_batch_iter = bundle_block_refs_to_logical_batches( - iter(block_refs), batch_size=batch_size - ) - logical_batches = list(logical_batch_iter) - assert logical_batches == [[block_refs[0][0]], [block_refs[1][0]]] - - # Case 3: Multiple blocks in a batch (`batch_size` is 2) - num_blocks = 4 - num_rows_per_block = 1 - batch_size = 2 - block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) - block_refs = list(block_iter) - logical_batch_iter = bundle_block_refs_to_logical_batches( - iter(block_refs), batch_size=batch_size - ) - logical_batches = list(logical_batch_iter) - assert logical_batches == [ - [block_refs[0][0], block_refs[1][0]], - [block_refs[2][0], block_refs[3][0]], - ] - - # Case 4: Batches overlap across multiple blocks unevenly - num_blocks = 4 - num_rows_per_block = 2 - batch_size = 3 - block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) - block_refs = list(block_iter) - logical_batch_iter = bundle_block_refs_to_logical_batches( - iter(block_refs), batch_size=batch_size - ) - logical_batches = list(logical_batch_iter) - assert logical_batches == [ - [block_refs[0][0], block_refs[1][0]], - [block_refs[2][0]], - [block_refs[3][0]], # Leftover block. - ] - - # Case 5: Batches overlap across multiple blocks unevenly, dropping the last - # incomplete batch. - num_blocks = 4 - num_rows_per_block = 2 - batch_size = 3 - block_iter = block_generator(num_rows=num_rows_per_block, num_blocks=num_blocks) - block_refs = list(block_iter) - logical_batch_iter = bundle_block_refs_to_logical_batches( - iter(block_refs), batch_size=batch_size, drop_last=True - ) - logical_batches = list(logical_batch_iter) - assert logical_batches == [ - [block_refs[0][0], block_refs[1][0]], - [block_refs[2][0]], - ] - - @pytest.mark.parametrize("num_batches_to_prefetch", [1, 2]) -def test_prefetch_batches_locally(num_batches_to_prefetch): +@pytest.mark.parametrize("batch_size", [None, 1, 4]) +def test_prefetch_batches_locally(num_batches_to_prefetch, batch_size): class DummyPrefetcher(BlockPrefetcher): def __init__(self): self.windows = [] def prefetch_blocks(self, blocks: List[Block]): + if not batch_size: + assert len(blocks) == num_batches_to_prefetch + else: + assert ( + sum(len(block) for block in blocks) + >= batch_size * num_batches_to_prefetch + ) self.windows.append(blocks) - num_batches = 10 + num_blocks = 10 + num_rows = 2 prefetcher = DummyPrefetcher() - block_iter = iter([[i] for i in range(num_batches)]) + blocks = list(block_generator(num_blocks=num_blocks, num_rows=num_rows)) prefetch_block_iter = prefetch_batches_locally( - block_iter, + iter(blocks), prefetcher=prefetcher, num_batches_to_prefetch=num_batches_to_prefetch, + batch_size=batch_size, ) - batch_count = 1 - for _ in prefetch_block_iter: - batch_count += 1 - if batch_count < num_batches: - # Test that we are actually prefetching. - assert len(prefetcher.windows) == batch_count - - windows = prefetcher.windows - assert all(len(window) == num_batches_to_prefetch for window in windows) + block_count = 0 + prefetched_blocks = [] + previous_num_windows = 1 + + for block in prefetch_block_iter: + prefetched_blocks.append(block) + block_count += 1 + remaining_rows = (num_blocks - block_count) * num_rows + if batch_size is None and block_count < num_blocks - num_batches_to_prefetch: + # Test that we are actually prefetching in advance if this is not the last + # block. + assert len(prefetcher.windows) == previous_num_windows + 1 + previous_num_windows = len(prefetcher.windows) + elif batch_size and remaining_rows > batch_size * num_batches_to_prefetch: + # Test that we are actually prefetching in advance if this is not the last + # batch. + assert len(prefetcher.windows) == previous_num_windows + 1 + previous_num_windows = len(prefetcher.windows) + + # Test that original blocks are unchanged. + assert prefetched_blocks == [block for block, metadata, in blocks] def test_restore_from_original_order(): From 0feeb2d4a0a98405254ae1c6b497d79d91a08e3e Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 13:49:46 -0700 Subject: [PATCH 34/75] syntax Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 91681aa02de27..a01f16092db11 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -19,20 +19,20 @@ - [Worker thread 2] Batch 3 formatting + collating - [Raylet] Batches 4 + 5 fetched to local object store memory -At any point in time there are prefetch_batches+1 batches in local heap memory -And the next set of prefetch_batches in local object store memory +At any point in time there are prefetch_batches+1 batches in local heap memory. +And the next set of prefetch_batches in local object store memory. The actual steps are as follows: In a single async thread, do the following: 1. Trigger Ray local prefetching of `prefetch_batches` worth of block object references. - 2. Resolve (i.e. call `ray.get()`) on the block references + 2. Resolve (i.e. call `ray.get()`) on the block references. 3. Perform the necessary batch slicing to construct full batches, possibly shuffling if necessary. 4. Then, in a threadpool consisting of `prefetch_batches` threads: 3. Format the batches to the provided batch format. - 4. Apply the collate function + 4. Apply the collate function. 5. Fetch outputs from the threadpool, maintaining order of the batches. """ From 00a5455559a962048ccb071cfe16ae7611fd70d2 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 14:15:50 -0700 Subject: [PATCH 35/75] update interfaces Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/interfaces.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index f474b84c7971b..67cfd0a6eac36 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,10 +1,12 @@ -from typing import Any, NamedTuple +from dataclasses import dataclass +from typing import Any from ray.types import ObjectRef from ray.data.block import Block, DataBatch -class Batch(NamedTuple): +@dataclass +class Batch: """A batch of data with a corresponding index. Attributes: @@ -17,7 +19,7 @@ class Batch(NamedTuple): data: DataBatch -class CollatedBatch(NamedTuple): +class CollatedBatch(Batch): """A batch of collated data with a corresponding index. Attributes: From 5a309df0179c1202de9f3767eb81820b81597902 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 14:26:15 -0700 Subject: [PATCH 36/75] wip Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 60 +++++++++---------- .../ray/data/_internal/block_batching/util.py | 5 +- .../tests/block_batching/test_iter_batches.py | 2 + 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 741ed0acb6b30..318528e0e92bd 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -1,4 +1,5 @@ import collections +import sys from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple import ray @@ -8,10 +9,19 @@ Batch, BlockPrefetcher, ) -from ray.data._internal.block_batching.util import ActorBlockPrefetcher, WaitBlockPrefetcher, resolve_block_refs, blocks_to_batches, format_batches, collate +from ray.data._internal.block_batching.util import ActorBlockPrefetcher, WaitBlockPrefetcher, resolve_block_refs, blocks_to_batches, format_batches, collate, extract_data_from_batch, make_async_gen from ray.data._internal.stats import DatasetStats from ray.data.context import DatasetContext +if sys.version_info >= (3, 7): + from contextlib import nullcontext +else: + from contextlib import contextmanager + + @contextmanager + def nullcontext(enter_result=None): + yield enter_result + def iter_batches( block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], *, @@ -55,8 +65,8 @@ def iter_batches( 3. Perform the necessary batch slicing to construct full batches, possibly shuffling if necessary. 4. Then, in a threadpool consisting of `prefetch_batches` threads: - 3. Format the batches to the provided batch format. - 4. Apply the collate function. + 1. Format the batches to the provided batch format. + 2. Apply the collate function. 5. Fetch outputs from the threadpool, maintaining order of the batches. Args: @@ -138,24 +148,22 @@ def _async_iter_batches( ) # Step 4: Use a threadpool for formatting and collation. - batch_iter = _batch_in_threadpool( + batch_iter = _format_in_threadpool( batch_iter, stats=stats, batch_format=batch_format, collate_fn=collate_fn, - ensure_copy=ensure_copy, num_threadpool_workers=prefetch_batches, ) # Step 5: Restore original order. batch_iter: Iterator[Batch] = restore_from_original_order(batch_iter) - for batch in batch_iter: - yield batch.data + yield from extract_data_from_batch(batch_iter) # Run everything in a separate thread to not block the main thread when waiting # for streaming results. - async_batch_iter = _make_async_gen( + async_batch_iter = make_async_gen( block_refs, fn=_async_iter_batches, num_workers=1 ) @@ -169,12 +177,11 @@ def _async_iter_batches( yield next_batch -def _batch_in_threadpool( - logical_batch_iterator: Iterator[LogicalBatch], +def _format_in_threadpool( + batch_iter: Iterator[Batch], stats: DatasetStats, batch_format: str = "default", collate_fn: Optional[Callable[[DataBatch], Any]] = None, - ensure_copy: bool = False, num_threadpool_workers: int = 0, ) -> Iterator[Batch]: """Executes the batching, formatting, and collation logic in a threadpool. @@ -188,39 +195,32 @@ def _batch_in_threadpool( select ``pandas.DataFrame`` or "pyarrow" to select ``pyarrow.Table``. Default is "default". collate_fn: A function to apply to each data batch before returning it. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). num_threadpool_workers: The number of threads to use in the threadpool. """ def threadpool_computations( - logical_batch_iter: Iterator[LogicalBatch], + batch_iter: Iterator[Batch], ) -> Iterator[Batch]: - # Step 4.1: Resolve the blocks. - resolved_batch_iter = _resolve_logical_batch(logical_batch_iter, stats=stats) - - # Step 4.2: Slice the blocks to create the batch. - batch_iter = _construct_batch_from_logical_batch( - resolved_batch_iter, ensure_copy=ensure_copy, stats=stats - ) - - # Step 4.3: Format the batches. - formatted_batch_iter = _format_batches( + # Step 4.1: Format the batches. + formatted_batch_iter = format_batches( batch_iter, batch_format=batch_format, stats=stats ) # Step 4.4: Apply the collate function if applicable. if collate_fn is not None: - formatted_batch_iter = _collate( + formatted_batch_iter = collate( formatted_batch_iter, collate_fn=collate_fn, stats=stats ) yield from formatted_batch_iter - return _make_async_gen( - base_iterator=logical_batch_iterator, - fn=threadpool_computations, - num_workers=num_threadpool_workers, - ) + if num_threadpool_workers > 0: + return make_async_gen( + base_iterator=batch_iter, + fn=threadpool_computations, + num_workers=num_threadpool_workers, + ) + else: + return threadpool_computations(batch_iter) def prefetch_batches_locally( block_ref_iter: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 99a01e24e0dfa..bef06c1754769 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -178,6 +178,7 @@ def format_batches( def collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], + stats: Optional[DatasetStats], ) -> Iterator[CollatedBatch]: """Returns an iterator with the provided collate_fn applied to items of the batch iterator. @@ -186,7 +187,8 @@ def collate( batch_iter: An iterator over formatted batches. """ for batch in batch_iter: - collated_batch = collate_fn(batch.data) + with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + collated_batch = collate_fn(batch.data) yield CollatedBatch(batch.batch_idx, collated_batch) @@ -194,7 +196,6 @@ def extract_data_from_batch(batch_iter: Iterator[Batch]) -> Iterator[Any]: for batch in batch_iter: yield batch.data - def make_async_gen( base_iterator: Iterator[T], fn: Callable[[Iterator[T]], Iterator[U]], diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 373c94f1b3a02..77a89772d493d 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -2,6 +2,7 @@ import time from typing import Iterator, List, Tuple +import pandas as pd import pyarrow as pa from ray.data.block import Block, BlockMetadata @@ -10,6 +11,7 @@ BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( + iter_batches, prefetch_batches_locally, restore_from_original_order, ) From 717f3ca44183d2aedc80bee161c75462a1573b61 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 14:26:47 -0700 Subject: [PATCH 37/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 9488d2f7ad49f..30c3bd6574cda 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -69,7 +69,7 @@ def resolve_block_refs( misses += current_miss unknowns += current_unknown - with stats.iter_get_s.thread_timer() if stats else nullcontext(): + with stats.iter_get_s.timer() if stats else nullcontext(): blocks = ray.get(block_refs) for block_ref in block_refs: trace_deallocation(block_ref, loc="iter_batches", free=eager_free) From 1c6dd853d4968f8f67df38e66eddd8b2569501f7 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 14:28:06 -0700 Subject: [PATCH 38/75] merge conflicts Signed-off-by: amogkam --- .../block_batching/block_batching.py | 94 ------------------- .../ray/data/_internal/block_batching/util.py | 2 +- 2 files changed, 1 insertion(+), 95 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 524260dcdb44c..1f762611b2bb8 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -6,15 +6,12 @@ import ray from ray.data._internal.block_batching.interfaces import BlockPrefetcher from ray.data._internal.block_batching.util import ( -<<<<<<< HEAD -======= resolve_block_refs, blocks_to_batches, format_batches, collate, extract_data_from_batch, make_async_gen, ->>>>>>> 0feeb2d4a0a98405254ae1c6b497d79d91a08e3e WaitBlockPrefetcher, ActorBlockPrefetcher, ) @@ -167,16 +164,12 @@ def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]: batch_iter = extract_data_from_batch(batch_iter) yield from batch_iter -<<<<<<< HEAD - batch_iter = _iterator_fn(blocks) -======= if prefetch_batches > 0: batch_iter = make_async_gen( blocks, fn=_iterator_fn, num_workers=prefetch_batches ) else: batch_iter = _iterator_fn(blocks) ->>>>>>> 0feeb2d4a0a98405254ae1c6b497d79d91a08e3e for formatted_batch in batch_iter: user_timer = stats.iter_user_s.timer() if stats else nullcontext() @@ -219,90 +212,3 @@ def _prefetch_blocks( except StopIteration: pass yield block_ref -<<<<<<< HEAD - trace_deallocation( - block_ref, "block_batching._prefetch_blocks", free=eager_free - ) - - -def _blocks_to_batches( - block_iter: Iterator[Block], - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, - batch_size: Optional[int] = None, - drop_last: bool = False, - shuffle_buffer_min_size: Optional[int] = None, - shuffle_seed: Optional[int] = None, - ensure_copy: bool = False, -) -> Iterator[Block]: - """Given an iterator over blocks, returns an iterator over blocks - of the appropriate bacth size. - - If the shuffling configurations are specified, then the - output blocks contain shuffled data. - - Args: - block_iter: An iterator over blocks. - stats: Dataset stats object used to store block batching time. - batch_size: Record batch size, or None to let the system pick. - drop_last: Whether to drop the last batch if it's incomplete. - ensure_copy: Whether batches are always copied from the underlying base - blocks (not zero-copy views). - - Returns: - An iterator over blocks of the given size that are potentially shuffled. - """ - if shuffle_buffer_min_size is not None: - batcher = ShufflingBatcher( - batch_size=batch_size, - shuffle_buffer_min_size=shuffle_buffer_min_size, - shuffle_seed=shuffle_seed, - ) - else: - batcher = Batcher(batch_size=batch_size, ensure_copy=ensure_copy) - - def get_iter_next_batch_s_timer(): - return stats.iter_create_batch_s.timer() if stats else nullcontext() - - for block in block_iter: - batcher.add(block) - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Signal to the batcher that there are no more blocks to add. - batcher.done_adding() - - # Get any leftover batches in ShufflingBatcher. - while batcher.has_batch(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - # Get any remaining data. - if not drop_last and batcher.has_any(): - with get_iter_next_batch_s_timer(): - batch = batcher.next_batch() - yield batch - - -def _format_batches( - block_iter: Iterator[Block], - batch_format: str, - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, -) -> Iterator[DataBatch]: - """Given an iterator of blocks, returns an iterator of formatted batches. - - Args: - block_iter: An iterator over blocks. - batch_format: The batch format to use. - - Returns: - An iterator over formatted batches. - """ - for block in block_iter: - with stats.iter_format_batch_s.timer() if stats else nullcontext(): - batch = BlockAccessor.for_block(block).to_batch_format(batch_format) - yield batch -======= ->>>>>>> 0feeb2d4a0a98405254ae1c6b497d79d91a08e3e diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index bef06c1754769..70ae0120dfae7 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -68,7 +68,7 @@ def resolve_block_refs( misses += current_miss unknowns += current_unknown - with stats.iter_get_s.thread_timer() if stats else nullcontext(): + with stats.iter_get_s.timer() if stats else nullcontext(): blocks = ray.get(block_refs) for block_ref in block_refs: trace_deallocation(block_ref, loc="iter_batches", free=eager_free) From ebaa8ebea4ad15a87a237158950c0d2aa15ec57f Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 15:24:34 -0700 Subject: [PATCH 39/75] fix Signed-off-by: amogkam --- .../_internal/block_batching/block_batching.py | 15 ++++++--------- .../_internal/block_batching/iter_batches.py | 4 ++-- .../ray/data/_internal/block_batching/util.py | 17 ++++++++--------- .../ray/data/tests/block_batching/test_util.py | 2 +- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index db6a719a03a45..478ce41f05ae7 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -104,15 +104,12 @@ def batch_block_refs( eager_free = clear_block_after_read and DatasetContext.get_current().eager_free block_iter = resolve_block_refs( - map( - list, - _prefetch_blocks( - block_ref_iter=block_refs, - prefetcher=prefetcher, - stats=stats, - num_blocks_to_prefetch=prefetch_blocks, - clear_block_after_read=clear_block_after_read, - ), + _prefetch_blocks( + block_ref_iter=block_refs, + prefetcher=prefetcher, + stats=stats, + num_blocks_to_prefetch=prefetch_blocks, + clear_block_after_read=clear_block_after_read, ), stats=stats, eager_free=eager_free, diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index a01f16092db11..6540821971215 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -1,5 +1,5 @@ import collections -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, Optional, Tuple from ray.types import ObjectRef from ray.data.block import Block, BlockMetadata @@ -42,7 +42,7 @@ def prefetch_batches_locally( prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, batch_size: Optional[int], -) -> Iterator[List[ObjectRef[Block]]]: +) -> Iterator[ObjectRef[Block]]: """Given an iterator of batched block references, returns an iterator over the same block references while prefetching `num_batches_to_prefetch` batches in advance. diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 30c3bd6574cda..21484b13df806 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -48,7 +48,7 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: def resolve_block_refs( - block_ref_iter: Iterator[List[ObjectRef[Block]]], + block_ref_iter: Iterator[ObjectRef[Block]], eager_free: bool = False, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[Block]: @@ -63,18 +63,17 @@ def resolve_block_refs( misses = 0 unknowns = 0 - for block_refs in block_ref_iter: - current_hit, current_miss, current_unknown = _calculate_ref_hits(block_refs) + for block_ref in block_ref_iter: + current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) hits += current_hit misses += current_miss unknowns += current_unknown - with stats.iter_get_s.timer() if stats else nullcontext(): - blocks = ray.get(block_refs) - for block_ref in block_refs: - trace_deallocation(block_ref, loc="iter_batches", free=eager_free) - for block in blocks: - yield block + # TODO(amogkam): Optimized further by batching multiple references in a single + # `ray.get()` call. + block = ray.get(block_ref) + trace_deallocation(block_ref, loc="iter_batches", free=eager_free) + yield block if stats: stats.iter_blocks_local += hits diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 551df1cae632d..66c1fb85108aa 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -22,7 +22,7 @@ def block_generator(num_rows: int, num_blocks: int): def test_resolve_block_refs(ray_start_regular_shared): - block_refs = [[ray.put(0), ray.put(1)], [ray.put(2)]] + block_refs = [ray.put(0), ray.put(1), ray.put(2)] resolved_iter = resolve_block_refs(iter(block_refs)) assert list(resolved_iter) == [0, 1, 2] From bcdc1db5b1373d6daa83655f779aa71934ae01f9 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 15:31:08 -0700 Subject: [PATCH 40/75] comment Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 6540821971215..3d3c5d9e6b861 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -59,6 +59,8 @@ def prefetch_batches_locally( if batch_size: num_rows_to_prefetch = num_batches_to_prefetch * batch_size + else: + num_rows_to_prefetch = None # Create and fetch the initial window. while True: @@ -68,8 +70,12 @@ def prefetch_batches_locally( current_window_size += next_block_ref_and_metadata[1].num_rows except StopIteration: break + # Stop adding if the number of rows in this window is greater than + # requested batch size. if batch_size and current_window_size >= num_rows_to_prefetch: break + # Stop adding if batch_size is None and the number of blocks in this window + # is greater than requested batches to prefetch. elif not batch_size and len(sliding_window) >= num_batches_to_prefetch: break From 1b8c2a6cbb09692ebc9204195593aeedf0552007 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 15:36:17 -0700 Subject: [PATCH 41/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/block_batching.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 478ce41f05ae7..8a4fdce0a85ed 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -109,7 +109,6 @@ def batch_block_refs( prefetcher=prefetcher, stats=stats, num_blocks_to_prefetch=prefetch_blocks, - clear_block_after_read=clear_block_after_read, ), stats=stats, eager_free=eager_free, From d4df6ae4b527812a79c92a2d9f3695646b51f5e5 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 15:43:35 -0700 Subject: [PATCH 42/75] update Signed-off-by: amogkam --- .../block_batching/block_batching.py | 12 +-- .../_internal/block_batching/iter_batches.py | 27 +++--- .../ray/data/_internal/block_batching/util.py | 3 + python/ray/data/dataset_iterator.py | 6 +- .../block_batching/test_block_batching.py | 85 ------------------- 5 files changed, 23 insertions(+), 110 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 1673456b4010c..1cf524f56945b 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -11,7 +11,6 @@ format_batches, collate, extract_data_from_batch, - make_async_gen, WaitBlockPrefetcher, ActorBlockPrefetcher, ) @@ -100,7 +99,6 @@ def batch_block_refs( _prefetch_blocks( block_ref_iter=block_refs, prefetcher=prefetcher, - stats=stats, num_blocks_to_prefetch=prefetch_blocks, ), stats=stats, @@ -155,17 +153,12 @@ def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]: ) if collate_fn is not None: - batch_iter = collate(batch_iter, collate_fn=collate_fn) + batch_iter = collate(batch_iter, collate_fn=collate_fn, stats=stats) batch_iter = extract_data_from_batch(batch_iter) yield from batch_iter - if prefetch_batches > 0: - batch_iter = make_async_gen( - blocks, fn=_iterator_fn, num_workers=prefetch_batches - ) - else: - batch_iter = _iterator_fn(blocks) + batch_iter = _iterator_fn(blocks) for formatted_batch in batch_iter: user_timer = stats.iter_user_s.timer() if stats else nullcontext() @@ -177,7 +170,6 @@ def _prefetch_blocks( block_ref_iter: Iterator[ObjectRef[Block]], prefetcher: BlockPrefetcher, num_blocks_to_prefetch: int, - stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[ObjectRef[Block]]: """Given an iterable of Block Object References, returns an iterator over these object reference while prefetching `num_block_to_prefetch` diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index a3f0b4f501db7..b78ef9902a1e3 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -9,7 +9,16 @@ Batch, BlockPrefetcher, ) -from ray.data._internal.block_batching.util import ActorBlockPrefetcher, WaitBlockPrefetcher, resolve_block_refs, blocks_to_batches, format_batches, collate, extract_data_from_batch, make_async_gen +from ray.data._internal.block_batching.util import ( + ActorBlockPrefetcher, + WaitBlockPrefetcher, + resolve_block_refs, + blocks_to_batches, + format_batches, + collate, + extract_data_from_batch, + make_async_gen, +) from ray.data._internal.stats import DatasetStats from ray.data.context import DatasetContext @@ -22,6 +31,7 @@ def nullcontext(enter_result=None): yield enter_result + def iter_batches( block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], *, @@ -119,7 +129,7 @@ def iter_batches( def _async_iter_batches( block_refs: Iterator[ObjectRef[Block]], ) -> Iterator[DataBatch]: - + if prefetch_batches > 0: # Step 1: Prefetch logical batches locally. block_refs = prefetch_batches_locally( @@ -131,11 +141,9 @@ def _async_iter_batches( # Step 2: Resolve the blocks. block_iter = resolve_block_refs( - block_ref_iter=block_refs, - eager_free=eager_free, - stats=stats + block_ref_iter=block_refs, eager_free=eager_free, stats=stats ) - + # Step 3: Batch and shuffle the resolved blocks. batch_iter = blocks_to_batches( block_iter=block_iter, @@ -144,7 +152,7 @@ def _async_iter_batches( drop_last=drop_last, shuffle_buffer_min_size=shuffle_buffer_min_size, shuffle_seed=shuffle_seed, - ensure_copy=ensure_copy + ensure_copy=ensure_copy, ) # Step 4: Use a threadpool for formatting and collation. @@ -163,9 +171,7 @@ def _async_iter_batches( # Run everything in a separate thread to not block the main thread when waiting # for streaming results. - async_batch_iter = make_async_gen( - block_refs, fn=_async_iter_batches, num_workers=1 - ) + async_batch_iter = make_async_gen(block_refs, fn=_async_iter_batches, num_workers=1) while True: with stats.iter_total_blocked_s.timer() if stats else nullcontext(): @@ -222,6 +228,7 @@ def threadpool_computations( else: return threadpool_computations(batch_iter) + def prefetch_batches_locally( block_ref_iter: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], prefetcher: BlockPrefetcher, diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index e6a0bbffed12b..30ddc79237332 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -6,6 +6,7 @@ import ray from ray.types import ObjectRef +from ray.actor import ActorHandle from ray.data.block import Block, BlockAccessor, DataBatch from ray.data._internal.batcher import Batcher, ShufflingBatcher from ray.data._internal.block_batching.interfaces import ( @@ -196,6 +197,7 @@ def extract_data_from_batch(batch_iter: Iterator[Batch]) -> Iterator[Any]: for batch in batch_iter: yield batch.data + def make_async_gen( base_iterator: Iterator[T], fn: Callable[[Iterator[T]], Iterator[U]], @@ -279,6 +281,7 @@ def execute_computation(thread_index: int): if num_threads_finished >= num_workers: break + PREFETCHER_ACTOR_NAMESPACE = "ray.dataset" diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 57117ebff53bf..0016813debef8 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -135,11 +135,7 @@ def iter_batches( """ context = DatasetContext.get_current() - if not context.use_streaming_executor: - # Always use legacy iter_batches for bulk executor. - use_legacy = True - else: - use_legacy = context.use_legacy_iter_batches + use_legacy = context.use_legacy_iter_batches if prefetch_blocks > 0 and not use_legacy: raise DeprecationWarning( diff --git a/python/ray/data/tests/block_batching/test_block_batching.py b/python/ray/data/tests/block_batching/test_block_batching.py index d0158d8350af3..67b44d8b95424 100644 --- a/python/ray/data/tests/block_batching/test_block_batching.py +++ b/python/ray/data/tests/block_batching/test_block_batching.py @@ -73,91 +73,6 @@ def prefetch_blocks(self, blocks: List[Block]): assert all(len(window) == num_blocks_to_prefetch for window in windows) -<<<<<<< HEAD -@pytest.mark.parametrize("block_size", [1, 10]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_blocks_to_batches(block_size, drop_last): - num_blocks = 5 - block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) - - batch_size = 3 - batch_iter = _blocks_to_batches( - block_iter, batch_size=batch_size, drop_last=drop_last - ) - - if drop_last: - for batch in batch_iter: - assert len(batch) == batch_size - else: - full_batches = 0 - leftover_batches = 0 - - dataset_size = block_size * num_blocks - for batch in batch_iter: - if len(batch) == batch_size: - full_batches += 1 - if len(batch) == (dataset_size % batch_size): - leftover_batches += 1 - - assert leftover_batches == 1 - assert full_batches == (dataset_size // batch_size) - - -@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) -def test_format_batches(batch_format): - block_iter = block_generator(num_rows=2, num_blocks=2) - batch_iter = _format_batches(block_iter, batch_format=batch_format) - - for batch in batch_iter: - if batch_format == "pandas": - assert isinstance(batch, pd.DataFrame) - elif batch_format == "arrow": - assert isinstance(batch, pa.Table) - elif batch_format == "numpy": - assert isinstance(batch, dict) - assert isinstance(batch["foo"], np.ndarray) -======= -# Test for 3 cases -# 1. Batch size is less than block size -# 2. Batch size is more than block size -# 3. Block size is not divisble by batch size -@pytest.mark.parametrize("batch_size", [4, 10, 7]) -def test_async_batch_fetching(batch_size): - blocks = block_generator(num_blocks=5, num_rows=8) - - def sleep_batch_format(batch_iter, *args, **kwargs): - for batch in batch_iter: - time.sleep(2) - yield batch - - with mock.patch( - "ray.data._internal.block_batching.util.format_batches", - sleep_batch_format, - ): - batch_iter = batch_blocks( - batch_size=batch_size, blocks=blocks, prefetch_batches=1 - ) - outputs = [] - start_time = time.time() - for batch in batch_iter: - time.sleep(3) - outputs.append(batch) - end_time = time.time() - - total_time = end_time - start_time - # Total time should be based on number of times the udf is called - # (which is equal to len(outputs)). - # The 2 seconds sleep in sleep_batch_format is overlapped, so does not count - # towards total time. - assert total_time < len(outputs) * 3 + 3 - - # There should be no dropped rows. - assert sum(len(output_batch) for output_batch in outputs) == 40, sum( - len(output_batch) for output_batch in outputs - ) # 5 blocks with 8 rows each. ->>>>>>> 0feeb2d4a0a98405254ae1c6b497d79d91a08e3e - - if __name__ == "__main__": import sys From 0145b2da593fb2353f4d4f0b5214902624e77254 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 16:16:10 -0700 Subject: [PATCH 43/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 9 ++++++++- .../ray/data/tests/block_batching/test_iter_batches.py | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index b78ef9902a1e3..4630ee797cace 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -127,7 +127,7 @@ def iter_batches( eager_free = clear_block_after_read and DatasetContext.get_current().eager_free def _async_iter_batches( - block_refs: Iterator[ObjectRef[Block]], + block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], ) -> Iterator[DataBatch]: if prefetch_batches > 0: @@ -138,6 +138,13 @@ def _async_iter_batches( num_batches_to_prefetch=prefetch_batches, batch_size=batch_size, ) + else: + + def _drop_metadata(block_ref_iter): + for block_ref, metadata in block_ref_iter: + yield block_ref + + block_refs = _drop_metadata(block_refs) # Step 2: Resolve the blocks. block_iter = resolve_block_refs( diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 77a89772d493d..73f4e307fcd76 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,3 +1,4 @@ +import itertools import pytest import time from typing import Iterator, List, Tuple @@ -5,6 +6,7 @@ import pandas as pd import pyarrow as pa +import ray from ray.data.block import Block, BlockMetadata from ray.data._internal.block_batching.interfaces import ( Batch, @@ -104,7 +106,8 @@ def test_iter_batches_e2e(ray_start_regular_shared, batch_size, drop_last): def collate_fn(batch: pd.DataFrame): return batch + 1 - block_refs_iter = block_generator(num_blocks=4, num_rows=2) + block_refs_iter = itertools.starmap(lambda block, metadata: (ray.put(block), metadata), block_generator(num_blocks=4, num_rows=2)) + output_batches = iter_batches( block_refs_iter, From 58d73ec5265d712429e832b5120e2d2d2509f4cb Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 16:23:17 -0700 Subject: [PATCH 44/75] address comments Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 29 +++++++++---------- .../ray/data/_internal/block_batching/util.py | 3 +- .../tests/block_batching/test_iter_batches.py | 8 ++--- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 3d3c5d9e6b861..c755533d7caad 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -31,8 +31,8 @@ 3. Perform the necessary batch slicing to construct full batches, possibly shuffling if necessary. 4. Then, in a threadpool consisting of `prefetch_batches` threads: - 3. Format the batches to the provided batch format. - 4. Apply the collate function. + a. Format the batches to the provided batch format. + b. Apply the collate function. 5. Fetch outputs from the threadpool, maintaining order of the batches. """ @@ -57,34 +57,31 @@ def prefetch_batches_locally( sliding_window = collections.deque() current_window_size = 0 - if batch_size: + if batch_size is None: num_rows_to_prefetch = num_batches_to_prefetch * batch_size else: num_rows_to_prefetch = None # Create and fetch the initial window. - while True: + # Stop adding if the number of rows in this window is greater than requested + # batch size, or if the batch size is None and the number of blocks in this window + # is greater than requested batches to prefetch. + while (batch_size is not None and current_window_size >= num_rows_to_prefetch) or ( + batch_size is None and len(sliding_window) >= num_batches_to_prefetch + ): try: next_block_ref_and_metadata = next(block_ref_iter) - sliding_window.append(next_block_ref_and_metadata) - current_window_size += next_block_ref_and_metadata[1].num_rows except StopIteration: break - # Stop adding if the number of rows in this window is greater than - # requested batch size. - if batch_size and current_window_size >= num_rows_to_prefetch: - break - # Stop adding if batch_size is None and the number of blocks in this window - # is greater than requested batches to prefetch. - elif not batch_size and len(sliding_window) >= num_batches_to_prefetch: - break + sliding_window.append(next_block_ref_and_metadata) + current_window_size += next_block_ref_and_metadata[1].num_rows prefetcher.prefetch_blocks([block_ref for block_ref, _ in list(sliding_window)]) while sliding_window: block_ref, metadata = sliding_window.popleft() current_window_size -= metadata.num_rows - if not batch_size or current_window_size < num_rows_to_prefetch: + if batch_size is None or current_window_size < num_rows_to_prefetch: try: sliding_window.append(next(block_ref_iter)) prefetcher.prefetch_blocks( @@ -95,7 +92,7 @@ def prefetch_batches_locally( yield block_ref -def restore_from_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: +def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` This function will yield items from `base_iterator` in the correct order based on diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 21484b13df806..6e4355bc6c06a 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -71,7 +71,8 @@ def resolve_block_refs( # TODO(amogkam): Optimized further by batching multiple references in a single # `ray.get()` call. - block = ray.get(block_ref) + with stats.iter_get_s.timer() if stats else nullcontext(): + block = ray.get(block_ref) trace_deallocation(block_ref, loc="iter_batches", free=eager_free) yield block diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 6a4afbdffcb92..b6e722970252f 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -10,7 +10,7 @@ ) from ray.data._internal.block_batching.iter_batches import ( prefetch_batches_locally, - restore_from_original_order, + restore_original_order, ) @@ -35,7 +35,7 @@ def __init__(self): self.windows = [] def prefetch_blocks(self, blocks: List[Block]): - if not batch_size: + if batch_size is None: assert len(blocks) == num_batches_to_prefetch else: assert ( @@ -68,7 +68,7 @@ def prefetch_blocks(self, blocks: List[Block]): # block. assert len(prefetcher.windows) == previous_num_windows + 1 previous_num_windows = len(prefetcher.windows) - elif batch_size and remaining_rows > batch_size * num_batches_to_prefetch: + elif batch_size is not None and remaining_rows > batch_size * num_batches_to_prefetch: # Test that we are actually prefetching in advance if this is not the last # batch. assert len(prefetcher.windows) == previous_num_windows + 1 @@ -86,7 +86,7 @@ def test_restore_from_original_order(): Batch(2, None), ] - ordered = list(restore_from_original_order(iter(base_iterator))) + ordered = list(restore_original_order(iter(base_iterator))) idx = [batch.batch_idx for batch in ordered] assert idx == [0, 1, 2, 3] From 620d52e80b13551b618eb027b533abf7040e7f06 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 16:26:17 -0700 Subject: [PATCH 45/75] update Signed-off-by: amogkam --- .../ray/data/_internal/block_batching/iter_batches.py | 2 +- .../data/tests/block_batching/test_iter_batches.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 2283e2c50e15d..03ada3b027363 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -172,7 +172,7 @@ def _drop_metadata(block_ref_iter): ) # Step 5: Restore original order. - batch_iter: Iterator[Batch] = restore_from_original_order(batch_iter) + batch_iter: Iterator[Batch] = restore_original_order(batch_iter) yield from extract_data_from_batch(batch_iter) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 7108dff43bba9..fa3c336a669e6 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -73,7 +73,10 @@ def prefetch_blocks(self, blocks: List[Block]): # block. assert len(prefetcher.windows) == previous_num_windows + 1 previous_num_windows = len(prefetcher.windows) - elif batch_size is not None and remaining_rows > batch_size * num_batches_to_prefetch: + elif ( + batch_size is not None + and remaining_rows > batch_size * num_batches_to_prefetch + ): # Test that we are actually prefetching in advance if this is not the last # batch. assert len(prefetcher.windows) == previous_num_windows + 1 @@ -106,8 +109,10 @@ def test_iter_batches_e2e(ray_start_regular_shared, batch_size, drop_last): def collate_fn(batch: pd.DataFrame): return batch + 1 - block_refs_iter = itertools.starmap(lambda block, metadata: (ray.put(block), metadata), block_generator(num_blocks=4, num_rows=2)) - + block_refs_iter = itertools.starmap( + lambda block, metadata: (ray.put(block), metadata), + block_generator(num_blocks=4, num_rows=2), + ) output_batches = iter_batches( block_refs_iter, From 3b8494c8dab4950b8f75863aa7bc40aca4d98486 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 16:32:02 -0700 Subject: [PATCH 46/75] update Signed-off-by: amogkam --- python/ray/data/dataset_iterator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 0016813debef8..7ee52b755c888 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -135,7 +135,13 @@ def iter_batches( """ context = DatasetContext.get_current() - use_legacy = context.use_legacy_iter_batches + if not context.use_streaming_executor: + # Always use legacy iter_batches for bulk executor. + use_legacy = True + if not prefetch_blocks and prefetch_batches: + prefetch_blocks = prefetch_batches + else: + use_legacy = context.use_legacy_iter_batches if prefetch_blocks > 0 and not use_legacy: raise DeprecationWarning( From 9d111af9b9889afe332f5f9eb58bcfb9bda5f0ef Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 16:46:58 -0700 Subject: [PATCH 47/75] lint Signed-off-by: amogkam --- python/ray/data/tests/block_batching/test_iter_batches.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b6e722970252f..f27be2db7aa99 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -68,7 +68,10 @@ def prefetch_blocks(self, blocks: List[Block]): # block. assert len(prefetcher.windows) == previous_num_windows + 1 previous_num_windows = len(prefetcher.windows) - elif batch_size is not None and remaining_rows > batch_size * num_batches_to_prefetch: + elif ( + batch_size is not None + and remaining_rows > batch_size * num_batches_to_prefetch + ): # Test that we are actually prefetching in advance if this is not the last # batch. assert len(prefetcher.windows) == previous_num_windows + 1 From 497eb825ef0789739882b48bc4ad6b8e31891d89 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:00:30 -0700 Subject: [PATCH 48/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index c755533d7caad..88ff1ab499d4d 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -57,7 +57,7 @@ def prefetch_batches_locally( sliding_window = collections.deque() current_window_size = 0 - if batch_size is None: + if batch_size is not None: num_rows_to_prefetch = num_batches_to_prefetch * batch_size else: num_rows_to_prefetch = None From ddb94607dd961578190f433d514701065ef06866 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:21:20 -0700 Subject: [PATCH 49/75] lock stats Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 624997887237e..2f679ac57a2c6 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -47,7 +47,8 @@ def timer(self) -> None: try: yield finally: - self.add(time.thread_time() - time_start) + with self.lock: + self.add(time.thread_time() - time_start) def add(self, value: float) -> None: self._value += value From 5d4587a90e98c02f08014c50f08c15e260fcc8ea Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:22:10 -0700 Subject: [PATCH 50/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/iter_batches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 88ff1ab499d4d..24f8ba816e79e 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -66,8 +66,8 @@ def prefetch_batches_locally( # Stop adding if the number of rows in this window is greater than requested # batch size, or if the batch size is None and the number of blocks in this window # is greater than requested batches to prefetch. - while (batch_size is not None and current_window_size >= num_rows_to_prefetch) or ( - batch_size is None and len(sliding_window) >= num_batches_to_prefetch + while (batch_size is not None and current_window_size < num_rows_to_prefetch) or ( + batch_size is None and len(sliding_window) < num_batches_to_prefetch ): try: next_block_ref_and_metadata = next(block_ref_iter) From 6313be7f1500b108c306b5d23f2a4763a2310b12 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:33:47 -0700 Subject: [PATCH 51/75] remove lock Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 2f679ac57a2c6..2f7c1f7c7a3d6 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -2,7 +2,6 @@ from dataclasses import dataclass import time from contextlib import contextmanager -import threading from typing import Dict, List, Optional, Set, Tuple, Union, Any import numpy as np @@ -39,16 +38,13 @@ def __init__(self): self._max: float = 0 self._total_count: float = 0 - self.lock = threading.Lock() - @contextmanager def timer(self) -> None: time_start = time.thread_time() try: yield finally: - with self.lock: - self.add(time.thread_time() - time_start) + self.add(time.thread_time() - time_start) def add(self, value: float) -> None: self._value += value From 171aa144d98c84612ab482e67ce6fc01916a4878 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:45:48 -0700 Subject: [PATCH 52/75] fix stats Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 2f7c1f7c7a3d6..eac3f95834663 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -671,39 +671,34 @@ def __str__(self) -> str: def to_string(self) -> str: out = "" - if ( - self.total_time.get() - or self.next_time.get() - or self.format_time.get() - or self.get_time.get() - or self.user_time.get() - or self.block_time.get() - or self.collate_time.get() - ): - out += "\nDataset iterator time breakdown:\n" + out += "\nDataset iterator time breakdown:\n" + if self.block_time.get(): out += "* Total time user code is blocked: {}\n".format( fmt(self.block_time.get()) ) + if self.user_time.get(): out += "* Total time in user code: {}\n".format(fmt(self.user_time.get())) + if self.total_time.get(): out += "* Total time overall: {}\n".format(fmt(self.total_time.get())) - out += "* Num blocks local: {}\n".format(self.iter_blocks_local) - out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) - out += "* Num blocks unknown location: {}\n".format( - self.iter_unknown_location - ) - out += "* Batch iteration time breakdown:\n" + out += "* Num blocks local: {}\n".format(self.iter_blocks_local) + out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) + out += "* Num blocks unknown location: {}\n".format(self.iter_unknown_location) + out += "* Batch iteration time breakdown:\n" + if self.get_time.get(): out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format( fmt(self.get_time.min()), fmt(self.get_time.max()), fmt(self.get_time.avg()), fmt(self.get_time.get()), ) + if self.next_time.get(): out += " * In batch creation: {} min, {} max, {} avg, {} total\n".format( fmt(self.next_time.min()), fmt(self.next_time.max()), fmt(self.next_time.avg()), fmt(self.next_time.get()), ) + if self.format_time.get(): out += ( " * In batch formatting: {} min, {} max, {} avg, {} total\n".format( fmt(self.format_time.min()), @@ -712,6 +707,7 @@ def to_string(self) -> str: fmt(self.format_time.get()), ) ) + if self.collate_time.get(): out += " * In collate_fn: {} min, {} max, {} avg, {} total\n".format( fmt(self.collate_time.min()), fmt(self.collate_time.max()), From 1c7dfe93df06df61b784b3b9f53bfcf39b5ad34e Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:48:01 -0700 Subject: [PATCH 53/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 30ddc79237332..3051ef35146d5 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -77,9 +77,9 @@ def resolve_block_refs( yield block if stats: - stats.iter_blocks_local += hits - stats.iter_blocks_remote += misses - stats.iter_unknown_location += unknowns + stats.iter_blocks_local = hits + stats.iter_blocks_remote = misses + stats.iter_unknown_location = unknowns def blocks_to_batches( From deb7b979fd79f9d5d68fa0bd170960aa6cc42c92 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:48:26 -0700 Subject: [PATCH 54/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 6e4355bc6c06a..a7eb7e99c5126 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -77,9 +77,9 @@ def resolve_block_refs( yield block if stats: - stats.iter_blocks_local += hits - stats.iter_blocks_remote += misses - stats.iter_unknown_location += unknowns + stats.iter_blocks_local = hits + stats.iter_blocks_remote = misses + stats.iter_unknown_location = unknowns def blocks_to_batches( From d47dae1f5ac48e5535c13de49569127536389f30 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 17:52:54 -0700 Subject: [PATCH 55/75] fix Signed-off-by: amogkam --- release/release_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 0e8ee310789e8..42fb33b5d3f29 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -3981,7 +3981,7 @@ team: data cluster: cluster_env: app_config.yaml - cluster_compute: multi_node_benchmark_compute_yaml + cluster_compute: multi_node_benchmark_compute.yaml run: # Expect the benchmark to finish around 30 minutes. From 771d7c9a1f107ff63255154e151303cd7da8b7ff Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 20:46:08 -0700 Subject: [PATCH 56/75] ci fixes Signed-off-by: amogkam --- doc/source/ray-air/doc_code/air_ingest.py | 2 +- python/ray/air/util/check_ingest.py | 24 ++++++++++++------- .../block_batching/block_batching.py | 3 ++- .../ray/data/_internal/block_batching/util.py | 2 +- python/ray/data/dataset_iterator.py | 2 -- .../tests/block_batching/test_iter_batches.py | 5 +++- .../data/tests/test_dataset_consumption.py | 10 ++++---- .../workloads/data_benchmark.py | 2 +- .../dataset/data_ingest_benchmark.py | 4 ++-- 9 files changed, 32 insertions(+), 22 deletions(-) diff --git a/doc/source/ray-air/doc_code/air_ingest.py b/doc/source/ray-air/doc_code/air_ingest.py index f8ec5442d44b7..6f2e29d6a58ac 100644 --- a/doc/source/ray-air/doc_code/air_ingest.py +++ b/doc/source/ray-air/doc_code/air_ingest.py @@ -27,7 +27,7 @@ datasets={"train": dataset}, preprocessor=preprocessor, num_epochs=1, # Stop after this number of epochs is read. - prefetch_blocks=1, # Number of blocks to prefetch when reading data. + prefetch_batches=1, # Number of batches to prefetch when reading data. batch_size=None, # Use whole blocks as batches. ) trainer.fit() diff --git a/python/ray/air/util/check_ingest.py b/python/ray/air/util/check_ingest.py index c91ac0330d909..08c43a9f87305 100755 --- a/python/ray/air/util/check_ingest.py +++ b/python/ray/air/util/check_ingest.py @@ -29,7 +29,7 @@ class DummyTrainer(DataParallelTrainer): scaling_config: Configuration for how to scale training. This is the same as for :class:`~ray.train.base_trainer.BaseTrainer`. num_epochs: How many many times to iterate through the datasets for. - prefetch_blocks: The number of blocks to prefetch ahead of the + prefetch_batches: The number of batches to prefetch ahead of the current block during the scan. This is the same as :meth:`~ray.data.dataset.Dataset.iter_batches` time_preprocessing_separately: Whether to time the preprocessing separately @@ -44,16 +44,17 @@ def __init__( *args, scaling_config: Optional[ScalingConfig] = None, num_epochs: int = 1, - prefetch_blocks: int = 1, + prefetch_batches: int = 1, batch_size: Optional[int] = 4096, time_preprocessing_separately: bool = False, - **kwargs, + # Deprecated. + prefetch_blocks: int = 0**kwargs, ): if not scaling_config: scaling_config = ScalingConfig(num_workers=1) super().__init__( train_loop_per_worker=DummyTrainer.make_train_loop( - num_epochs, prefetch_blocks, batch_size + num_epochs, prefetch_batches, batch_size ), *args, scaling_config=scaling_config, @@ -81,7 +82,10 @@ def preprocess_datasets(self): @staticmethod def make_train_loop( - num_epochs: int, prefetch_blocks: int, batch_size: Optional[int] + num_epochs: int, + prefetch_batches: int, + prefetch_blocks: int, + batch_size: Optional[int], ): """Make a debug train loop that runs for the given amount of epochs.""" @@ -99,7 +103,9 @@ def train_loop_per_worker(): epochs_read += 1 batch_start = time.perf_counter() for batch in data_shard.iter_batches( - prefetch_blocks=prefetch_blocks, batch_size=batch_size + prefetch_batches=prefetch_batches, + prefetch_blocks=prefetch_blocks, + batch_size=batch_size, ): batch_delay = time.perf_counter() - batch_start batch_delays.append(batch_delay) @@ -189,11 +195,11 @@ def make_local_dataset_iterator( "--num-epochs", "-e", type=int, default=1, help="Number of epochs to read." ) parser.add_argument( - "--prefetch-blocks", + "--prefetch-batches", "-b", type=int, default=1, - help="Number of blocks to prefetch when reading data.", + help="Number of batches to prefetch when reading data.", ) args = parser.parse_args() @@ -215,7 +221,7 @@ def make_local_dataset_iterator( datasets={"train": dataset}, preprocessor=preprocessor, num_epochs=args.num_epochs, - prefetch_blocks=args.prefetch_blocks, + prefetch_batches=args.prefetch_batches, dataset_config={"train": DatasetConfig()}, batch_size=None, ) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 1cf524f56945b..95e2c6aec3bc4 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -81,7 +81,8 @@ def batch_block_refs( Returns: An iterator over record batches. """ - stats._legacy_iter_batches = True + if stats: + stats._legacy_iter_batches = True context = DatasetContext.get_current() if ( diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 3051ef35146d5..0cdddef9a7b22 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -179,7 +179,7 @@ def format_batches( def collate( batch_iter: Iterator[Batch], collate_fn: Optional[Callable[[DataBatch], Any]], - stats: Optional[DatasetStats], + stats: Optional[DatasetStats] = None, ) -> Iterator[CollatedBatch]: """Returns an iterator with the provided collate_fn applied to items of the batch iterator. diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 7ee52b755c888..57117ebff53bf 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -138,8 +138,6 @@ def iter_batches( if not context.use_streaming_executor: # Always use legacy iter_batches for bulk executor. use_legacy = True - if not prefetch_blocks and prefetch_batches: - prefetch_blocks = prefetch_batches else: use_legacy = context.use_legacy_iter_batches diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index fa3c336a669e6..53a0892569bbb 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -153,7 +153,10 @@ def collate_fn(batch): time.sleep(2) return batch - block_refs_iter = block_generator(num_blocks=20, num_rows=2) + block_refs_iter = itertools.starmap( + lambda block, metadata: (ray.put(block), metadata), + block_generator(num_blocks=20, num_rows=2), + ) start_time = time.time() output_batches = iter_batches( block_refs_iter, batch_size=None, collate_fn=collate_fn, prefetch_batches=4 diff --git a/python/ray/data/tests/test_dataset_consumption.py b/python/ray/data/tests/test_dataset_consumption.py index f4882dc4e743b..cfd216e76304d 100644 --- a/python/ray/data/tests/test_dataset_consumption.py +++ b/python/ray/data/tests/test_dataset_consumption.py @@ -609,7 +609,7 @@ def test_iter_batches_basic(ray_start_regular_shared): # Prefetch. batches = list( - ds.iter_batches(prefetch_blocks=1, batch_size=None, batch_format="pandas") + ds.iter_batches(prefetch_batches=1, batch_size=None, batch_format="pandas") ) assert len(batches) == len(dfs) for batch, df in zip(batches, dfs): @@ -618,7 +618,9 @@ def test_iter_batches_basic(ray_start_regular_shared): batch_size = 2 batches = list( - ds.iter_batches(prefetch_blocks=2, batch_size=batch_size, batch_format="pandas") + ds.iter_batches( + prefetch_batches=2, batch_size=batch_size, batch_format="pandas" + ) ) assert all(len(batch) == batch_size for batch in batches) assert len(batches) == math.ceil( @@ -631,7 +633,7 @@ def test_iter_batches_basic(ray_start_regular_shared): # Prefetch more than number of blocks. batches = list( ds.iter_batches( - prefetch_blocks=len(dfs), batch_size=None, batch_format="pandas" + prefetch_batches=len(dfs), batch_size=None, batch_format="pandas" ) ) assert len(batches) == len(dfs) @@ -645,7 +647,7 @@ def test_iter_batches_basic(ray_start_regular_shared): try: context.actor_prefetcher_enabled = False batches = list( - ds.iter_batches(prefetch_blocks=1, batch_size=None, batch_format="pandas") + ds.iter_batches(prefetch_batches=1, batch_size=None, batch_format="pandas") ) assert len(batches) == len(dfs) for batch, df in zip(batches, dfs): diff --git a/release/air_tests/air_benchmarks/workloads/data_benchmark.py b/release/air_tests/air_benchmarks/workloads/data_benchmark.py index 837704ac80906..dc34435cc6487 100644 --- a/release/air_tests/air_benchmarks/workloads/data_benchmark.py +++ b/release/air_tests/air_benchmarks/workloads/data_benchmark.py @@ -34,7 +34,7 @@ def run_ingest_bulk(dataset, num_workers, num_cpus_per_worker): datasets={"train": dataset}, preprocessor=dummy_prep, num_epochs=1, - prefetch_blocks=1, + prefetch_batches=1, dataset_config={"train": DatasetConfig(split=True)}, ) trainer.fit() diff --git a/release/nightly_tests/dataset/data_ingest_benchmark.py b/release/nightly_tests/dataset/data_ingest_benchmark.py index ec598e04a4aed..f36870638df9f 100644 --- a/release/nightly_tests/dataset/data_ingest_benchmark.py +++ b/release/nightly_tests/dataset/data_ingest_benchmark.py @@ -27,7 +27,7 @@ def get_location(self): def DoConsume(split, rank): - prefetch_blocks = 1 + prefetch_batches = 1 batch_size = 4096 num_epochs = 1 @@ -51,7 +51,7 @@ def generate_epochs(data, epochs: int): epochs_read += 1 batch_start = time.perf_counter() for batch in epoch_data.iter_batches( - prefetch_blocks=prefetch_blocks, batch_size=batch_size + prefetch_batches=prefetch_batches, batch_size=batch_size ): batch_delay = time.perf_counter() - batch_start batch_delays.append(batch_delay) From 3c752ac1b815c287deac81e1ff0e1957a27731e1 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 20:54:21 -0700 Subject: [PATCH 57/75] more fixes Signed-off-by: amogkam --- python/ray/data/tests/test_bulk_executor.py | 1 + python/ray/train/batch_predictor.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_bulk_executor.py b/python/ray/data/tests/test_bulk_executor.py index a5d84a0b3510a..c1135f82a07ce 100644 --- a/python/ray/data/tests/test_bulk_executor.py +++ b/python/ray/data/tests/test_bulk_executor.py @@ -98,6 +98,7 @@ def test_basic_stats(ray_start_10_cpus_shared): # TODO(ekl) remove this test once we have the new backend on by default. def test_e2e_bulk_sanity(ray_start_10_cpus_shared): DatasetContext.get_current().new_execution_backend = True + DatasetContext.get_current().use_streaming_executor = False result = ray.data.range(5).map(lambda x: x + 1) assert result.take_all() == [1, 2, 3, 4, 5], result diff --git a/python/ray/train/batch_predictor.py b/python/ray/train/batch_predictor.py index 82345a70f395d..392be5501c34a 100644 --- a/python/ray/train/batch_predictor.py +++ b/python/ray/train/batch_predictor.py @@ -354,7 +354,6 @@ def __call__(self, input_batch: DataBatchType) -> DataBatchType: if override_prep is not None else predict_stage_batch_format, batch_size=batch_size, - prefetch_batches=int(num_gpus_per_worker > 0), fn_constructor_kwargs={"override_prep": override_prep}, **ray_remote_args, ) From 58691074637b0ce605cc30930abd66137ac0edac Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 20:59:01 -0700 Subject: [PATCH 58/75] fix Signed-off-by: amogkam --- python/ray/air/util/check_ingest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/air/util/check_ingest.py b/python/ray/air/util/check_ingest.py index 08c43a9f87305..6f8d9f3d43f42 100755 --- a/python/ray/air/util/check_ingest.py +++ b/python/ray/air/util/check_ingest.py @@ -48,7 +48,8 @@ def __init__( batch_size: Optional[int] = 4096, time_preprocessing_separately: bool = False, # Deprecated. - prefetch_blocks: int = 0**kwargs, + prefetch_blocks: int = 0, + **kwargs, ): if not scaling_config: scaling_config = ScalingConfig(num_workers=1) From e65f73f315ddb52e7ea79bd962a60324796cbe1d Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 21:03:52 -0700 Subject: [PATCH 59/75] fix stats Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 7383491203511..6b84e7d002f6a 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -306,7 +306,7 @@ def to_summary(self) -> "DatasetStatsSummary": iter_stats = IterStatsSummary( self._legacy_iter_batches, - self._iter_wait_s, + self.iter_wait_s, self.iter_get_s, self.iter_next_batch_s, self.iter_format_batch_s, From f84b9eb921d55c729441cc9674e15ba507ecfde8 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 21:17:56 -0700 Subject: [PATCH 60/75] more fix Signed-off-by: amogkam --- .../dataset_iterator/pipelined_dataset_iterator.py | 9 +++++++-- python/ray/data/dataset_iterator.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py index f776827571886..af298243a539a 100644 --- a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py +++ b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py @@ -34,8 +34,13 @@ def _to_block_iterator( ) -> Tuple[ Iterator[Tuple[ObjectRef[Block], BlockMetadata]], Optional[DatasetStats] ]: - ds = self._get_next_dataset() - return ds.iterator()._to_block_iterator() + epoch_pipeline = self._get_next_dataset() + + def block_iter(): + for ds in epoch_pipeline.iter_datasets(): + yield from ds._plan.execute().iter_blocks_with_metadata() + + return block_iter(), None def stats(self) -> str: return self._base_dataset_pipeline.stats() diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 57117ebff53bf..dfd08269b35d3 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -185,7 +185,8 @@ def drop_metadata(block_iterator): clear_block_after_read=True, ) - stats.iter_total_s.add(time.perf_counter() - time_start) + if stats: + stats.iter_total_s.add(time.perf_counter() - time_start) def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]: """Return a local row iterator over the dataset. From 738b2257eb9969c8249e8ddb50e26952fad3f130 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 21:41:11 -0700 Subject: [PATCH 61/75] change back trace deallocation Signed-off-by: amogkam --- .../ray/data/_internal/block_batching/block_batching.py | 7 ++++++- python/ray/data/_internal/block_batching/iter_batches.py | 9 ++++++--- python/ray/data/_internal/block_batching/util.py | 4 ---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 66b42ddd615e8..d0e7ab4a95218 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -14,6 +14,7 @@ WaitBlockPrefetcher, ActorBlockPrefetcher, ) +from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.data.block import Block, DataBatch from ray.data.context import DatasetContext @@ -101,9 +102,9 @@ def batch_block_refs( block_ref_iter=block_refs, prefetcher=prefetcher, num_blocks_to_prefetch=prefetch_blocks, + eager_free=eager_free, ), stats=stats, - eager_free=eager_free, ) yield from batch_blocks( @@ -171,6 +172,7 @@ def _prefetch_blocks( block_ref_iter: Iterator[ObjectRef[Block]], prefetcher: BlockPrefetcher, num_blocks_to_prefetch: int, + eager_free: bool = False, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[ObjectRef[Block]]: """Given an iterable of Block Object References, returns an iterator @@ -204,3 +206,6 @@ def _prefetch_blocks( except StopIteration: pass yield block_ref + trace_deallocation( + block_ref, "block_batching._prefetch_blocks", free=eager_free + ) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index abed089c1e047..45347a80d534d 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -19,6 +19,7 @@ extract_data_from_batch, make_async_gen, ) +from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetStats from ray.data.context import DatasetContext @@ -137,6 +138,7 @@ def _async_iter_batches( prefetcher=prefetcher, num_batches_to_prefetch=prefetch_batches, batch_size=batch_size, + eager_free=eager_free, ) else: @@ -147,9 +149,7 @@ def _drop_metadata(block_ref_iter): block_refs = _drop_metadata(block_refs) # Step 2: Resolve the blocks. - block_iter = resolve_block_refs( - block_ref_iter=block_refs, eager_free=eager_free, stats=stats - ) + block_iter = resolve_block_refs(block_ref_iter=block_refs, stats=stats) # Step 3: Batch and shuffle the resolved blocks. batch_iter = blocks_to_batches( @@ -241,6 +241,7 @@ def prefetch_batches_locally( prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, batch_size: Optional[int], + eager_free: bool = False, ) -> Iterator[ObjectRef[Block]]: """Given an iterator of batched block references, returns an iterator over the same block references while prefetching `num_batches_to_prefetch` batches in advance. @@ -251,6 +252,7 @@ def prefetch_batches_locally( num_batches_to_prefetch: The number of batches to prefetch ahead of the current batch during the scan. batch_size: User specified batch size, or None to let the system pick. + eager_free: Whether to eagerly free the object reference from the object store. """ sliding_window = collections.deque() @@ -289,6 +291,7 @@ def prefetch_batches_locally( except StopIteration: pass yield block_ref + trace_deallocation(block_ref, loc="iter_batches", free=eager_free) def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 0cdddef9a7b22..7b6782100023f 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -14,7 +14,6 @@ CollatedBatch, BlockPrefetcher, ) -from ray.data._internal.memory_tracing import trace_deallocation from ray.data._internal.stats import DatasetPipelineStats, DatasetStats from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -49,14 +48,12 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], - eager_free: bool = False, stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None, ) -> Iterator[Block]: """Resolves the block references for each logical batch. Args: block_ref_iter: An iterator over block object references. - eager_free: Whether to eagerly free the object reference from the object store. stats: An optional stats object to recording block hits and misses. """ hits = 0 @@ -73,7 +70,6 @@ def resolve_block_refs( # `ray.get()` call. with stats.iter_get_s.timer() if stats else nullcontext(): block = ray.get(block_ref) - trace_deallocation(block_ref, loc="iter_batches", free=eager_free) yield block if stats: From 06bbe2bfabc3ff0437eb144d9d43efb9e6b32e9a Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 21:52:48 -0700 Subject: [PATCH 62/75] fix Signed-off-by: amogkam --- python/ray/air/util/check_ingest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/air/util/check_ingest.py b/python/ray/air/util/check_ingest.py index 6f8d9f3d43f42..feae9960e43e4 100755 --- a/python/ray/air/util/check_ingest.py +++ b/python/ray/air/util/check_ingest.py @@ -55,7 +55,7 @@ def __init__( scaling_config = ScalingConfig(num_workers=1) super().__init__( train_loop_per_worker=DummyTrainer.make_train_loop( - num_epochs, prefetch_batches, batch_size + num_epochs, prefetch_batches, prefetch_blocks, batch_size ), *args, scaling_config=scaling_config, From c5da60190d8e4ff63dcffd7d2ace525aa70673fb Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 22:26:13 -0700 Subject: [PATCH 63/75] empty Signed-off-by: amogkam From fa3d13b1ea8cab97922dc5854a316577949eb129 Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 23:07:27 -0700 Subject: [PATCH 64/75] address comments Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 43 +++++++++---------- python/ray/data/_internal/stats.py | 2 +- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 45347a80d534d..cd6b9663652d2 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -39,7 +39,7 @@ def iter_batches( stats: Optional[DatasetStats] = None, clear_block_after_read: bool = False, batch_size: Optional[int] = None, - batch_format: str = "default", + batch_format: Optional[str] = "default", drop_last: bool = False, collate_fn: Optional[Callable[[DataBatch], Any]] = None, shuffle_buffer_min_size: Optional[int] = None, @@ -94,7 +94,8 @@ def iter_batches( Specify "default" to use the current block format (promoting Arrow to pandas automatically), "pandas" to select ``pandas.DataFrame`` or "pyarrow" to select - ``pyarrow.Table``. Default is "default". + ``pyarrow.Table``, or None to use entire blocks + as batches. Default is "default". drop_last: Whether to drop the last batch if it's incomplete. collate_fn: A function to apply to each data batch before returning it. shuffle_buffer_min_size: If non-None, the data will be randomly shuffled using a @@ -131,22 +132,14 @@ def _async_iter_batches( block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], ) -> Iterator[DataBatch]: - if prefetch_batches > 0: - # Step 1: Prefetch logical batches locally. - block_refs = prefetch_batches_locally( - block_ref_iter=block_refs, - prefetcher=prefetcher, - num_batches_to_prefetch=prefetch_batches, - batch_size=batch_size, - eager_free=eager_free, - ) - else: - - def _drop_metadata(block_ref_iter): - for block_ref, metadata in block_ref_iter: - yield block_ref - - block_refs = _drop_metadata(block_refs) + # Step 1: Prefetch logical batches locally. + block_refs = prefetch_batches_locally( + block_ref_iter=block_refs, + prefetcher=prefetcher, + num_batches_to_prefetch=prefetch_batches, + batch_size=batch_size, + eager_free=eager_free, + ) # Step 2: Resolve the blocks. block_iter = resolve_block_refs(block_ref_iter=block_refs, stats=stats) @@ -193,9 +186,9 @@ def _drop_metadata(block_ref_iter): def _format_in_threadpool( batch_iter: Iterator[Batch], stats: DatasetStats, - batch_format: str = "default", - collate_fn: Optional[Callable[[DataBatch], Any]] = None, - num_threadpool_workers: int = 0, + batch_format: Optional[str], + collate_fn: Optional[Callable[[DataBatch], Any]], + num_threadpool_workers: int, ) -> Iterator[Batch]: """Executes the batching, formatting, and collation logic in a threadpool. @@ -206,7 +199,8 @@ def _format_in_threadpool( Specify "default" to use the current block format (promoting Arrow to pandas automatically), "pandas" to select ``pandas.DataFrame`` or "pyarrow" to select - ``pyarrow.Table``. Default is "default". + ``pyarrow.Table``, or None to use entire blocks + as batches. collate_fn: A function to apply to each data batch before returning it. num_threadpool_workers: The number of threads to use in the threadpool. """ @@ -258,6 +252,11 @@ def prefetch_batches_locally( sliding_window = collections.deque() current_window_size = 0 + if num_batches_to_prefetch <= 0: + for block_ref, metadata in block_ref_iter: + yield block_ref + return + if batch_size is not None: num_rows_to_prefetch = num_batches_to_prefetch * batch_size else: diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 6b84e7d002f6a..65627b575c425 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -687,7 +687,7 @@ def to_string(self) -> str: out += "* Num blocks local: {}\n".format(self.iter_blocks_local) out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) out += "* Num blocks unknown location: {}\n".format(self.iter_unknown_location) - out += "* Batch iteration time breakdown:\n" + out += "* Batch iteration time breakdown (summed across prefetch threads):\n" if self.get_time.get(): out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format( fmt(self.get_time.min()), From 9ab19dceb43685e5602f766465cf20af9b56448e Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 23:21:35 -0700 Subject: [PATCH 65/75] default to 1 Signed-off-by: amogkam --- python/ray/data/dataset.py | 40 ++++++++++++++--------------- python/ray/data/dataset_iterator.py | 40 ++++++++++++++--------------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 7679b8f817a69..6c1400fb4a390 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3001,7 +3001,7 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]] def iter_batches( self, *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: Optional[int] = 256, batch_format: Optional[str] = "default", drop_last: bool = False, @@ -3024,9 +3024,9 @@ def iter_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -3070,7 +3070,7 @@ def iter_batches( def iter_torch_batches( self, *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: Optional[int] = 256, dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, device: Optional[str] = None, @@ -3107,9 +3107,9 @@ def iter_torch_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -3155,7 +3155,7 @@ def iter_torch_batches( def iter_tf_batches( self, *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: Optional[int] = 256, dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None, drop_last: bool = False, @@ -3191,9 +3191,9 @@ def iter_tf_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -3237,7 +3237,7 @@ def to_torch( Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]] ] = None, batch_size: int = 1, - prefetch_batches: int = 0, + prefetch_batches: int = 1, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, @@ -3308,9 +3308,9 @@ def to_torch( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch @@ -3359,7 +3359,7 @@ def to_tf( feature_columns: Union[str, List[str]], label_columns: Union[str, List[str]], *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: int = 1, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, @@ -3433,9 +3433,9 @@ def to_tf( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: Record batch size. Defaults to 1. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index dfd08269b35d3..2464af598c236 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -82,7 +82,7 @@ def _to_block_iterator( def iter_batches( self, *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: int = 256, batch_format: Optional[str] = "default", drop_last: bool = False, @@ -107,9 +107,9 @@ def iter_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -234,7 +234,7 @@ def schema(self) -> Union[type, "pyarrow.lib.Schema"]: def iter_torch_batches( self, *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: Optional[int] = 256, dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None, device: Optional[str] = None, @@ -267,9 +267,9 @@ def iter_torch_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -339,7 +339,7 @@ def collate_fn(batch: Union[np.ndarray, Dict[str, np.ndarray]]): def iter_tf_batches( self, *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: Optional[int] = 256, dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None, drop_last: bool = False, @@ -375,9 +375,9 @@ def iter_tf_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). The final batch may include fewer than ``batch_size`` rows if @@ -425,7 +425,7 @@ def to_torch( Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]] ] = None, batch_size: int = 1, - prefetch_batches: int = 0, + prefetch_batches: int = 1, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, local_shuffle_seed: Optional[int] = None, @@ -496,9 +496,9 @@ def to_torch( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch @@ -614,7 +614,7 @@ def to_tf( feature_columns: Union[str, List[str]], label_columns: Union[str, List[str]], *, - prefetch_batches: int = 0, + prefetch_batches: int = 1, batch_size: int = 1, drop_last: bool = False, local_shuffle_buffer_size: Optional[int] = None, @@ -690,9 +690,9 @@ def to_tf( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 0 (no prefetching enabled.) This is still - an alpha API. You can revert back to the old prefetching behavior by - setting `use_legacy_iter_batches` to True in the DatasetContext. + the collate_fn. Defaults to 1. This is still an alpha API. You can + revert back to the old prefetching behavior by setting + `use_legacy_iter_batches` to True in the DatasetContext. batch_size: Record batch size. Defaults to 1. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If From 52d476720f266d7c19c14b89d07366c246bdb84e Mon Sep 17 00:00:00 2001 From: amogkam Date: Fri, 24 Mar 2023 23:22:27 -0700 Subject: [PATCH 66/75] default to 1 Signed-off-by: amogkam --- python/ray/data/tests/block_batching/test_iter_batches.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 53a0892569bbb..c43463ea5ccc6 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -105,7 +105,10 @@ def test_restore_from_original_order(): # 3. Block size is not divisble by batch size @pytest.mark.parametrize("batch_size", [1, 4, 3]) @pytest.mark.parametrize("drop_last", [True, False]) -def test_iter_batches_e2e(ray_start_regular_shared, batch_size, drop_last): +@pytest.mark.parametrize("prefetch_batches", [0, 1]) +def test_iter_batches_e2e( + ray_start_regular_shared, batch_size, drop_last, prefetch_batches +): def collate_fn(batch: pd.DataFrame): return batch + 1 @@ -117,6 +120,7 @@ def collate_fn(batch: pd.DataFrame): output_batches = iter_batches( block_refs_iter, batch_size=batch_size, + prefetch_batches=prefetch_batches, batch_format="pandas", collate_fn=collate_fn, drop_last=drop_last, From ad2dbb2a5aa6c07514a45a7babf8aab477073a15 Mon Sep 17 00:00:00 2001 From: amogkam Date: Sat, 25 Mar 2023 12:38:41 -0700 Subject: [PATCH 67/75] fixes Signed-off-by: amogkam --- python/ray/data/_internal/logical/operators/map_operator.py | 2 -- python/ray/data/_internal/planner/plan_udf_map_op.py | 1 - python/ray/data/dataset_iterator.py | 1 - 3 files changed, 4 deletions(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 44453f50f76ba..0f41501ab72c8 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -87,7 +87,6 @@ def __init__( fn: BatchUDF, batch_size: Optional[int] = DEFAULT_BATCH_SIZE, batch_format: Optional[str] = "default", - prefetch_batches: int = 0, zero_copy_batch: bool = False, fn_args: Optional[Iterable[Any]] = None, fn_kwargs: Optional[Dict[str, Any]] = None, @@ -111,7 +110,6 @@ def __init__( ) self._batch_size = batch_size self._batch_format = batch_format - self._prefetch_batches = prefetch_batches self._zero_copy_batch = zero_copy_batch diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 49a47ae76ffe8..561f00025beeb 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -34,7 +34,6 @@ def _plan_udf_map_op( transform_fn = generate_map_batches_fn( batch_size=op._batch_size, batch_format=op._batch_format, - prefetch_batches=op._prefetch_batches, zero_copy_batch=op._zero_copy_batch, ) elif isinstance(op, MapRows): diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index 2464af598c236..a3ce0141dc62c 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -182,7 +182,6 @@ def drop_metadata(block_iterator): shuffle_buffer_min_size=local_shuffle_buffer_size, shuffle_seed=local_shuffle_seed, prefetch_batches=prefetch_batches, - clear_block_after_read=True, ) if stats: From 72dc3fbaa9bb0f602512a4d4101c789f02b266a6 Mon Sep 17 00:00:00 2001 From: amogkam Date: Sat, 25 Mar 2023 12:58:40 -0700 Subject: [PATCH 68/75] fix test_stats Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 90 +++++++++++++++++------------ python/ray/data/tests/test_stats.py | 13 +++-- 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 65627b575c425..6715d4479db28 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -675,49 +675,67 @@ def __str__(self) -> str: def to_string(self) -> str: out = "" - out += "\nDataset iterator time breakdown:\n" - if self.block_time.get(): - out += "* Total time user code is blocked: {}\n".format( - fmt(self.block_time.get()) - ) - if self.user_time.get(): - out += "* Total time in user code: {}\n".format(fmt(self.user_time.get())) - if self.total_time.get(): - out += "* Total time overall: {}\n".format(fmt(self.total_time.get())) - out += "* Num blocks local: {}\n".format(self.iter_blocks_local) - out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) - out += "* Num blocks unknown location: {}\n".format(self.iter_unknown_location) - out += "* Batch iteration time breakdown (summed across prefetch threads):\n" - if self.get_time.get(): - out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format( - fmt(self.get_time.min()), - fmt(self.get_time.max()), - fmt(self.get_time.avg()), - fmt(self.get_time.get()), - ) - if self.next_time.get(): - out += " * In batch creation: {} min, {} max, {} avg, {} total\n".format( - fmt(self.next_time.min()), - fmt(self.next_time.max()), - fmt(self.next_time.avg()), - fmt(self.next_time.get()), + if ( + self.block_time.get() + or self.total_time.get() + or self.get_time.get() + or self.next_time.get() + or self.format_time.get() + or self.collate_time.get() + ): + out += "\nDataset iterator time breakdown:\n" + if self.block_time.get(): + out += "* Total time user code is blocked: {}\n".format( + fmt(self.block_time.get()) + ) + if self.user_time.get(): + out += "* Total time in user code: {}\n".format( + fmt(self.user_time.get()) + ) + if self.total_time.get(): + out += "* Total time overall: {}\n".format(fmt(self.total_time.get())) + out += "* Num blocks local: {}\n".format(self.iter_blocks_local) + out += "* Num blocks remote: {}\n".format(self.iter_blocks_remote) + out += "* Num blocks unknown location: {}\n".format( + self.iter_unknown_location ) - if self.format_time.get(): out += ( - " * In batch formatting: {} min, {} max, {} avg, {} total\n".format( + "* Batch iteration time breakdown (summed across prefetch threads):\n" + ) + if self.get_time.get(): + out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format( + fmt(self.get_time.min()), + fmt(self.get_time.max()), + fmt(self.get_time.avg()), + fmt(self.get_time.get()), + ) + if self.next_time.get(): + batch_creation_str = ( + " * In batch creation: {} min, {} max, " "{} avg, {} total\n" + ) + out += batch_creation_str.format( + fmt(self.next_time.min()), + fmt(self.next_time.max()), + fmt(self.next_time.avg()), + fmt(self.next_time.get()), + ) + if self.format_time.get(): + format_str = ( + " * In batch formatting: {} min, {} max, " "{} avg, {} total\n" + ) + out += format_str.format( fmt(self.format_time.min()), fmt(self.format_time.max()), fmt(self.format_time.avg()), fmt(self.format_time.get()), ) - ) - if self.collate_time.get(): - out += " * In collate_fn: {} min, {} max, {} avg, {} total\n".format( - fmt(self.collate_time.min()), - fmt(self.collate_time.max()), - fmt(self.collate_time.avg()), - fmt(self.collate_time.get()), - ) + if self.collate_time.get(): + out += " * In collate_fn: {} min, {} max, {} avg, {} total\n".format( + fmt(self.collate_time.min()), + fmt(self.collate_time.max()), + fmt(self.collate_time.avg()), + fmt(self.collate_time.get()), + ) return out diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 0c4dc3c843679..a763741b79d9e 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -914,15 +914,16 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_dataset_context) {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, 'obj_store_mem_peak': N} Dataset iterator time breakdown: -* In ray.wait(): T -* In ray.get(): T +* Total time user code is blocked: T +* Total time in user code: T +* Total time overall: T * Num blocks local: Z * Num blocks remote: Z * Num blocks unknown location: N -* In next_batch(): T -* In format_batch(): T -* In user code: T -* Total time: T +* Batch iteration time breakdown (summed across prefetch threads): + * In ray.get(): T min, T max, T avg, T total + * In batch creation: T min, T max, T avg, T total + * In batch formatting: T min, T max, T avg, T total """ ) From 33195a094aae2cf3e36bae9a931eaf447260e4bf Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 12:10:16 -0700 Subject: [PATCH 69/75] update Signed-off-by: amogkam --- .../pipelined_dataset_iterator.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py index af298243a539a..8cbb579a48a34 100644 --- a/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py +++ b/python/ray/data/_internal/dataset_iterator/pipelined_dataset_iterator.py @@ -1,8 +1,8 @@ -from typing import TYPE_CHECKING, Optional, Union, Iterator, Tuple +from typing import Any, TYPE_CHECKING, Callable, Optional, Union, Iterator, Tuple import warnings from ray.types import ObjectRef -from ray.data.block import Block, BlockMetadata +from ray.data.block import Block, BlockMetadata, DataBatch from ray.data.dataset_iterator import DatasetIterator from ray.data._internal.stats import DatasetStats @@ -42,6 +42,31 @@ def block_iter(): return block_iter(), None + def iter_batches( + self, + *, + prefetch_batches: int = 0, + batch_size: int = 256, + batch_format: Optional[str] = "default", + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + _collate_fn: Optional[Callable[[DataBatch], Any]] = None, + # Deprecated. + prefetch_blocks: int = 0, + ) -> Iterator[DataBatch]: + # Set prefetch_batches to default of 0 for DatasetPipeline. + return super().iter_batches( + prefetch_batches=prefetch_batches, + batch_size=batch_size, + batch_format=batch_format, + drop_last=drop_last, + local_shuffle_buffer_size=local_shuffle_buffer_size, + local_shuffle_seed=local_shuffle_seed, + _collate_fn=_collate_fn, + prefetch_blocks=prefetch_blocks, + ) + def stats(self) -> str: return self._base_dataset_pipeline.stats() From e3c79bafeeea0513a8b9b86f2f54728add60f0e8 Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 12:14:13 -0700 Subject: [PATCH 70/75] update Signed-off-by: amogkam --- .../_internal/block_batching/iter_batches.py | 2 +- python/ray/data/_internal/stats.py | 4 ++-- python/ray/data/dataset.py | 20 +++++++++---------- python/ray/data/dataset_iterator.py | 20 +++++++++---------- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index cd6b9663652d2..01de86c0934d5 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -45,7 +45,7 @@ def iter_batches( shuffle_buffer_min_size: Optional[int] = None, shuffle_seed: Optional[int] = None, ensure_copy: bool = False, - prefetch_batches: int = 0, + prefetch_batches: int = 1, ) -> Iterator[DataBatch]: """Create formatted batches of data from an iterator of block object references and corresponding metadata. diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 6715d4479db28..46b11480324ec 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -40,11 +40,11 @@ def __init__(self): @contextmanager def timer(self) -> None: - time_start = time.thread_time() + time_start = time.perf_counter() try: yield finally: - self.add(time.thread_time() - time_start) + self.add(time.perf_counter() - time_start) def add(self, value: float) -> None: self._value += value diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 359419dad149b..76d50afbe9244 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3024,8 +3024,8 @@ def iter_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). @@ -3107,8 +3107,8 @@ def iter_torch_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). @@ -3191,8 +3191,8 @@ def iter_tf_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). @@ -3308,8 +3308,8 @@ def to_torch( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If @@ -3433,8 +3433,8 @@ def to_tf( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: Record batch size. Defaults to 1. drop_last: Set to True to drop the last incomplete batch, diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index a3ce0141dc62c..eeb2e8cfb8bd1 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -107,8 +107,8 @@ def iter_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). @@ -266,8 +266,8 @@ def iter_torch_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). @@ -374,8 +374,8 @@ def iter_tf_batches( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different number of rows). @@ -495,8 +495,8 @@ def to_torch( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If @@ -689,8 +689,8 @@ def to_tf( prefetch_batches: The number of batches to fetch ahead of the current batch to fetch. If set to greater than 0, a separate threadpool will be used to fetch the objects to the local node, format the batches, and apply - the collate_fn. Defaults to 1. This is still an alpha API. You can - revert back to the old prefetching behavior by setting + the collate_fn. Defaults to 1. You can revert back to the old + prefetching behavior that uses `prefetch_blocks` by setting `use_legacy_iter_batches` to True in the DatasetContext. batch_size: Record batch size. Defaults to 1. drop_last: Set to True to drop the last incomplete batch, From 2888f96282d2153d59ad5b64e6fc2ecce9cfd623 Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 12:16:31 -0700 Subject: [PATCH 71/75] fix iter_rows Signed-off-by: amogkam --- python/ray/data/dataset_iterator.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index eeb2e8cfb8bd1..fc07eb56e7794 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -144,7 +144,7 @@ def iter_batches( if prefetch_blocks > 0 and not use_legacy: raise DeprecationWarning( "`prefetch_blocks` arg is deprecated in Ray 2.4. Use " - "the`prefetch_batches` arg instead to specify the amount of " + "the `prefetch_batches` arg instead to specify the amount of " "prefetching in terms of batches instead of blocks. If you " "would like to use the legacy `iter_batches` codepath, " "you can enable it by setting `use_legacy_iter_batches` " @@ -209,13 +209,16 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]] Returns: An iterator over rows of the dataset. """ - for batch in self.iter_batches( - batch_size=None, + iter_batch_args = {"batch_size": None, "batch_format": None} + + context = DatasetContext.get_current() + if context.use_legacy_iter_batches: + iter_batch_args["prefetch_blocks"] = prefetch_blocks + else: # If batch_size is None, 1 block is exactly 1 batch. - prefetch_batches=prefetch_blocks, - prefetch_blocks=prefetch_blocks, - batch_format=None, - ): + iter_batch_args["prefetch_batches"] = prefetch_blocks + + for batch in self.iter_batches(**iter_batch_args): batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch)) for row in batch.iter_rows(): yield row From 5f65e43b69036e5578c4ad426a5194296d36077f Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 12:16:56 -0700 Subject: [PATCH 72/75] comment Signed-off-by: amogkam --- python/ray/data/dataset_iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/dataset_iterator.py b/python/ray/data/dataset_iterator.py index fc07eb56e7794..a79cfa41a9146 100644 --- a/python/ray/data/dataset_iterator.py +++ b/python/ray/data/dataset_iterator.py @@ -215,7 +215,7 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]] if context.use_legacy_iter_batches: iter_batch_args["prefetch_blocks"] = prefetch_blocks else: - # If batch_size is None, 1 block is exactly 1 batch. + # Since batch_size is None, 1 block is exactly 1 batch. iter_batch_args["prefetch_batches"] = prefetch_blocks for batch in self.iter_batches(**iter_batch_args): From d2958bd10b788b8a5de0312d7f874777f1f34030 Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 12:24:01 -0700 Subject: [PATCH 73/75] zero division Signed-off-by: amogkam --- python/ray/data/_internal/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 46b11480324ec..79bedc02ad6e4 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -64,7 +64,7 @@ def max(self) -> float: return self._max def avg(self) -> float: - return self._value / self._total_count + return self._value / self._total_count if self._total_count else float("inf") class _DatasetStatsBuilder: From 8308ee665cfbb8c6ca265e507c5056df9afef45f Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 12:36:04 -0700 Subject: [PATCH 74/75] fix Signed-off-by: amogkam --- python/ray/data/_internal/block_batching/block_batching.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index d0e7ab4a95218..414bc012e8d42 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -188,6 +188,9 @@ def _prefetch_blocks( if num_blocks_to_prefetch == 0: for block_ref in block_ref_iter: yield block_ref + trace_deallocation( + block_ref, "block_batching._prefetch_blocks", free=eager_free + ) window_size = num_blocks_to_prefetch # Create the initial set of blocks to prefetch. From 829311d66c135198fa999e868662669e9320207f Mon Sep 17 00:00:00 2001 From: amogkam Date: Mon, 27 Mar 2023 13:31:36 -0700 Subject: [PATCH 75/75] add init file Signed-off-by: amogkam --- python/ray/data/_internal/dataset_iterator/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/ray/data/_internal/dataset_iterator/__init__.py diff --git a/python/ray/data/_internal/dataset_iterator/__init__.py b/python/ray/data/_internal/dataset_iterator/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d