From ed560a7004ca5d70a6c6f5b19414644507e58aeb Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 10 Mar 2023 10:48:17 -0800 Subject: [PATCH] [data] [streaming] Add streaming_split() API (#32991) Signed-off-by: Jack He --- .../data/_internal/execution/interfaces.py | 9 + .../data/_internal/execution/legacy_compat.py | 31 ++- .../execution/operators/output_splitter.py | 90 ++++++-- .../execution/streaming_executor_state.py | 17 +- .../stream_split_dataset_iterator.py | 199 ++++++++++++++++++ python/ray/data/_internal/util.py | 2 +- python/ray/data/dataset.py | 47 ++++- python/ray/data/tests/test_operators.py | 42 ++++ .../ray/data/tests/test_streaming_executor.py | 31 ++- .../data/tests/test_streaming_integration.py | 45 ++++ 10 files changed, 485 insertions(+), 28 deletions(-) create mode 100644 python/ray/data/_internal/stream_split_dataset_iterator.py diff --git a/python/ray/data/_internal/execution/interfaces.py b/python/ray/data/_internal/execution/interfaces.py index 40a71a04f674..6b52c90df210 100644 --- a/python/ray/data/_internal/execution/interfaces.py +++ b/python/ray/data/_internal/execution/interfaces.py @@ -364,6 +364,15 @@ def get_work_refs(self) -> List[ray.ObjectRef]: """ return [] + def throttling_disabled(self) -> bool: + """Whether to disable resource throttling for this operator. + + This should return True for operators that only manipulate bundle metadata + (e.g., the OutputSplitter operator). This hints to the execution engine that + these operators should not be throttled based on resource usage. + """ + return False + def num_active_work_refs(self) -> int: """Return the number of active work refs. diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index 9d01e88b0b18..ddacb8051030 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -41,21 +41,43 @@ def execute_to_legacy_block_iterator( allow_clear_input_blocks: bool, dataset_uuid: str, ) -> Iterator[ObjectRef[Block]]: - """Execute a plan with the new executor and return a block iterator. + """Same as execute_to_legacy_bundle_iterator but returning blocks.""" + bundle_iter = execute_to_legacy_bundle_iterator( + executor, plan, allow_clear_input_blocks, dataset_uuid + ) + for bundle in bundle_iter: + for block, _ in bundle.blocks: + yield block + + +def execute_to_legacy_bundle_iterator( + executor: Executor, + plan: ExecutionPlan, + allow_clear_input_blocks: bool, + dataset_uuid: str, + dag_rewrite=None, +) -> Iterator[RefBundle]: + """Execute a plan with the new executor and return a bundle iterator. Args: executor: The executor to use. plan: The legacy plan to execute. allow_clear_input_blocks: Whether the executor may consider clearing blocks. dataset_uuid: UUID of the dataset for this execution. + dag_rewrite: Callback that can be used to mutate the DAG prior to execution. + This is currently used as a legacy hack to inject the OutputSplit operator + for `Dataset.streaming_split()`. Returns: - The output as a block iterator. + The output as a bundle iterator. """ + if DatasetContext.get_current().optimizer_enabled: dag, stats = get_execution_plan(plan._logical_plan).dag, None else: dag, stats = _to_operator_dag(plan, allow_clear_input_blocks) + if dag_rewrite: + dag = dag_rewrite(dag) # Enforce to preserve ordering if the plan has stages required to do so, such as # Zip and Sort. 
@@ -64,10 +86,7 @@ def execute_to_legacy_block_iterator( executor._options.preserve_order = True bundle_iter = executor.execute(dag, initial_stats=stats) - - for bundle in bundle_iter: - for block, _ in bundle.blocks: - yield block + return bundle_iter def execute_to_legacy_block_list( diff --git a/python/ray/data/_internal/execution/operators/output_splitter.py b/python/ray/data/_internal/execution/operators/output_splitter.py index c8e24c23f7f3..b75fcce0e09f 100644 --- a/python/ray/data/_internal/execution/operators/output_splitter.py +++ b/python/ray/data/_internal/execution/operators/output_splitter.py @@ -1,5 +1,5 @@ import math -from typing import List, Dict +from typing import List, Dict, Optional from ray.data.block import Block, BlockMetadata, BlockAccessor from ray.data._internal.remote_fn import cached_remote_fn @@ -7,7 +7,9 @@ from ray.data._internal.execution.interfaces import ( RefBundle, PhysicalOperator, + ExecutionOptions, ExecutionResources, + NodeIdStr, ) from ray.types import ObjectRef @@ -17,10 +19,12 @@ class OutputSplitter(PhysicalOperator): The output bundles of this operator will have a `bundle.output_split_idx` attr set to an integer from [0..n-1]. This operator tries to divide the rows evenly - across output splits. + across output splits. If the `equal` option is set, the operator will furthermore + guarantee an exact split of rows across outputs, truncating the Dataset as needed. - If the `equal` option is set, the operator will furthermore guarantee an exact - split of rows across outputs, truncating the Dataset as needed. + Implementation wise, this operator keeps an internal buffer of bundles. The buffer + has a minimum size calculated to enable a good locality hit rate, as well as ensure + we can satisfy the `equal` requirement. OutputSplitter does not provide any ordering guarantees. """ @@ -30,6 +34,7 @@ def __init__( input_op: PhysicalOperator, n: int, equal: bool, + locality_hints: Optional[List[NodeIdStr]] = None, ): super().__init__(f"split({n}, equal={equal})", [input_op]) self._equal = equal @@ -40,6 +45,40 @@ def __init__( # The number of rows output to each output split so far. self._num_output: List[int] = [0 for _ in range(n)] + if locality_hints is not None: + if n != len(locality_hints): + raise ValueError( + "Locality hints list must have length `n`: " + f"len({locality_hints}) != {n}" + ) + self._locality_hints = locality_hints + if locality_hints: + # To optimize locality, we should buffer a certain number of elements + # internally before dispatch to allow the locality algorithm a good chance + # of selecting a preferred location. We use a small multiple of `n` since + # it's reasonable to buffer a couple blocks per consumer. + self._min_buffer_size = 2 * n + else: + self._min_buffer_size = 0 + self._locality_hits = 0 + self._locality_misses = 0 + + def start(self, options: ExecutionOptions) -> None: + super().start(options) + # Force disable locality optimization. + if not options.actor_locality_enabled: + self._locality_hints = None + self._min_buffer_size = 0 + + def throttling_disabled(self) -> bool: + """Disables resource-based throttling. + + It doesn't make sense to throttle the inputs to this operator, since all that + would do is lower the buffer size and prevent us from emitting outputs / + reduce the locality hit rate. 
+ """ + return True + def has_next(self) -> bool: return len(self._output_queue) > 0 @@ -64,8 +103,8 @@ def add_input(self, bundle, input_index) -> None: def inputs_done(self) -> None: super().inputs_done() if not self._equal: - # There shouldn't be any buffered data if we're not in equal split mode. - assert not self._buffer + self._dispatch_bundles(dispatch_all=True) + assert not self._buffer, "Should have dispatched all bundles." return # Otherwise: @@ -101,21 +140,31 @@ def current_resource_usage(self) -> ExecutionResources: ) def progress_str(self) -> str: - if self._equal: - return f"{len(self._buffer)} buffered" - assert not self._buffer - return "" + if self._locality_hints: + return ( + f"[{self._locality_hits} locality hits, {self._locality_misses} misses]" + ) + else: + return "[locality disabled]" - def _dispatch_bundles(self) -> None: + def _dispatch_bundles(self, dispatch_all: bool = False) -> None: # Dispatch all dispatchable bundles from the internal buffer. # This may not dispatch all bundles when equal=True. - while self._buffer: + while self._buffer and ( + dispatch_all or len(self._buffer) >= self._min_buffer_size + ): target_index = self._select_output_index() target_bundle = self._pop_bundle_to_dispatch(target_index) if self._can_safely_dispatch(target_index, target_bundle.num_rows()): target_bundle.output_split_idx = target_index self._num_output[target_index] += target_bundle.num_rows() self._output_queue.append(target_bundle) + if self._locality_hints: + preferred_loc = self._locality_hints[target_index] + if self._get_location(target_bundle) == preferred_loc: + self._locality_hits += 1 + else: + self._locality_misses += 1 else: # Put it back and abort. self._buffer.insert(0, target_bundle) @@ -127,7 +176,12 @@ def _select_output_index(self) -> int: return i def _pop_bundle_to_dispatch(self, target_index: int) -> RefBundle: - # TODO implement locality aware bundle selection. + if self._locality_hints: + preferred_loc = self._locality_hints[target_index] + for bundle in self._buffer: + if self._get_location(bundle) == preferred_loc: + self._buffer.remove(bundle) + return bundle return self._buffer.pop(0) def _can_safely_dispatch(self, target_index: int, nrow: int) -> bool: @@ -164,6 +218,16 @@ def _split_from_buffer(self, nrow: int) -> List[RefBundle]: assert sum(b.num_rows() for b in output) == nrow, (acc, nrow) return output + def _get_location(self, bundle: RefBundle) -> Optional[NodeIdStr]: + """Ask Ray for the node id of the given bundle. + + This method may be overriden for testing. + + Returns: + A node id associated with the bundle, or None if unknown. + """ + return bundle.get_cached_location() + def _split(bundle: RefBundle, left_size: int) -> (RefBundle, RefBundle): left_blocks, left_meta = [], [] diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py index f6ae5f87dfac..34530ed145a6 100644 --- a/python/ray/data/_internal/execution/streaming_executor_state.py +++ b/python/ray/data/_internal/execution/streaming_executor_state.py @@ -159,11 +159,11 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> MaybeRefBundle # Scan the queue and look for outputs tagged for the given index. for i in range(len(self.outqueue)): bundle = self.outqueue[i] - if bundle is None: + if bundle is None or isinstance(bundle, Exception): # End of stream for this index! Note that we # do not remove the None, so that it can act # as the termination signal for all indices. 
- return None + return bundle elif bundle.output_split_idx == output_split_idx: self.outqueue.remove(bundle) return bundle @@ -324,9 +324,14 @@ def select_operator_to_run( if not ops: return None - # Equally penalize outqueue length and num bundles processing for backpressure. + # Run metadata-only operators first. After that, equally penalize outqueue length + # and num bundles processing for backpressure. return min( - ops, key=lambda op: len(topology[op].outqueue) + topology[op].num_processing() + ops, + key=lambda op: ( + not op.throttling_disabled(), + len(topology[op].outqueue) + topology[op].num_processing(), + ), ) @@ -353,6 +358,10 @@ def _execution_allowed( Returns: Whether the op is allowed to run. """ + + if op.throttling_disabled(): + return True + assert isinstance(global_usage, TopologyResourceUsage), global_usage # To avoid starvation problems when dealing with fractional resource types, # convert all quantities to integer (0 or 1) for deciding admissibility. This diff --git a/python/ray/data/_internal/stream_split_dataset_iterator.py b/python/ray/data/_internal/stream_split_dataset_iterator.py new file mode 100644 index 000000000000..c7249d0c4db8 --- /dev/null +++ b/python/ray/data/_internal/stream_split_dataset_iterator.py @@ -0,0 +1,199 @@ +import copy +import logging +import sys +import threading +from typing import ( + List, + Dict, + Optional, + Iterator, + Callable, + Any, + Union, + TYPE_CHECKING, +) + +import ray + +from ray.data.dataset_iterator import DatasetIterator +from ray.data.block import Block, DataBatch +from ray.data.context import DatasetContext +from ray.data._internal.execution.streaming_executor import StreamingExecutor +from ray.data._internal.execution.legacy_compat import ( + execute_to_legacy_bundle_iterator, +) +from ray.data._internal.block_batching import batch_block_refs +from ray.data._internal.execution.operators.output_splitter import OutputSplitter +from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle +from ray.types import ObjectRef +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +if TYPE_CHECKING: + import pyarrow + from ray.data import Dataset + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +logger = logging.getLogger(__name__) + + +class StreamSplitDatasetIterator(DatasetIterator): + """Implements a collection of iterators over a shared data stream.""" + + @staticmethod + def create( + base_dataset: "Dataset", + n: int, + equal: bool, + locality_hints: Optional[List[NodeIdStr]], + ) -> List["StreamSplitDatasetIterator"]: + """Create a split iterator from the given base Dataset and options. + + See also: `Dataset.streaming_split`. + """ + ctx = DatasetContext.get_current() + + # To avoid deadlock, the concurrency on this actor must be set to at least `n`. 
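+        # Each of the `n` consumers may be blocked in a get() call on the coordinator
+        # at the same time, so fewer concurrency slots could stall the other readers.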
+ coord_actor = SplitCoordinator.options( + max_concurrency=n, + scheduling_strategy=NodeAffinitySchedulingStrategy( + ray.get_runtime_context().get_node_id(), soft=False + ), + ).remote(ctx, base_dataset, n, equal, locality_hints) + + return [ + StreamSplitDatasetIterator(base_dataset, coord_actor, i) for i in range(n) + ] + + def __init__( + self, + base_dataset: "Dataset", + coord_actor: ray.actor.ActorHandle, + output_split_idx: int, + ): + self._base_dataset = base_dataset + self._coord_actor = coord_actor + self._output_split_idx = output_split_idx + + def iter_batches( + self, + *, + prefetch_blocks: int = 0, + batch_size: int = 256, + batch_format: Literal["default", "numpy", "pandas"] = "default", + drop_last: bool = False, + local_shuffle_buffer_size: Optional[int] = None, + local_shuffle_seed: Optional[int] = None, + _collate_fn: Optional[Callable[[DataBatch], Any]] = None, + ) -> Iterator[DataBatch]: + """Implements DatasetIterator.""" + + def gen_blocks() -> Iterator[ObjectRef[Block]]: + future: ObjectRef[ + Optional[ObjectRef[Block]] + ] = self._coord_actor.get.remote(self._output_split_idx) + while True: + block_ref: Optional[ObjectRef[Block]] = ray.get(future) + if not block_ref: + break + else: + future = self._coord_actor.get.remote(self._output_split_idx) + yield block_ref + + yield from batch_block_refs( + gen_blocks(), + stats=None, + prefetch_blocks=prefetch_blocks, + batch_size=batch_size, + batch_format=batch_format, + drop_last=drop_last, + collate_fn=_collate_fn, + shuffle_buffer_min_size=local_shuffle_buffer_size, + shuffle_seed=local_shuffle_seed, + ) + + def stats(self) -> str: + """Implements DatasetIterator.""" + return self._base_dataset.stats() + + def schema(self) -> Union[type, "pyarrow.lib.Schema"]: + """Implements DatasetIterator.""" + return self._base_dataset.schema() + + +@ray.remote(num_cpus=0) +class SplitCoordinator: + """Coordinator actor for routing blocks to output splits. + + This actor runs a streaming executor locally on its main thread. Clients can + retrieve results via actor calls running on other threads. + """ + + def __init__( + self, + ctx: DatasetContext, + dataset: "Dataset", + n: int, + equal: bool, + locality_hints: Optional[List[NodeIdStr]], + ): + # Automatically set locality with output to the specified location hints. + if locality_hints: + ctx.execution_options.locality_with_output = locality_hints + logger.info(f"Auto configuring locality_with_output={locality_hints}") + + DatasetContext._set_current(ctx) + self._base_dataset = dataset + self._n = n + self._equal = equal + self._locality_hints = locality_hints + self._finished = False + self._lock = threading.RLock() + # Guarded by self._lock. + self._next_bundle: Dict[int, RefBundle] = {} + + executor = StreamingExecutor(copy.deepcopy(ctx.execution_options)) + + def add_split_op(dag): + return OutputSplitter(dag, n, equal, locality_hints) + + self._output_iterator = execute_to_legacy_bundle_iterator( + executor, + dataset._plan, + True, + dataset._plan._dataset_uuid, + dag_rewrite=add_split_op, + ) + + def get(self, output_split_idx: int) -> Optional[ObjectRef[Block]]: + """Blocking get operation. + + This is intended to be called concurrently from multiple clients. + """ + try: + # Ensure there is at least one bundle. + with self._lock: + if output_split_idx in self._next_bundle: + next_bundle = self._next_bundle[output_split_idx] + else: + next_bundle = None + + # Fetch next bundle if needed. 
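+            # get_next() blocks until the shared streaming executor produces a bundle
+            # tagged for this split index, or raises StopIteration at end of stream.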
+ if next_bundle is None: + # This is a BLOCKING call, so do it outside the lock. + next_bundle = self._output_iterator.get_next(output_split_idx) + + block = next_bundle.blocks.pop()[0] + + # Accumulate any remaining blocks in next_bundle map as needed. + with self._lock: + self._next_bundle[output_split_idx] = next_bundle + if not next_bundle.blocks: + del self._next_bundle[output_split_idx] + + return block + except StopIteration: + return None diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py index f947e44fd763..8274423deaff 100644 --- a/python/ray/data/_internal/util.py +++ b/python/ray/data/_internal/util.py @@ -343,7 +343,7 @@ def _consumption_api( """ base = ( " will trigger execution of the lazy transformations performed on " - "this dataset, and will block until execution completes." + "this dataset." ) if delegate: message = delegate + base diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index ad7335d910b3..58314b05d4ce 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -52,6 +52,7 @@ from ray.data.dataset_iterator import DatasetIterator from ray.data._internal.block_list import BlockList from ray.data._internal.dataset_iterator_impl import DatasetIteratorImpl +from ray.data._internal.stream_split_dataset_iterator import StreamSplitDatasetIterator from ray.data._internal.compute import ( ActorPoolStrategy, CallableClass, @@ -147,7 +148,7 @@ from ray.data.dataset_pipeline import DatasetPipeline from ray.data.grouped_dataset import GroupedDataset - from ray.data._internal.execution.interfaces import Executor + from ray.data._internal.execution.interfaces import Executor, NodeIdStr from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType @@ -1145,6 +1146,44 @@ def process_batch(batch): return self.map_batches(process_batch) + @ConsumptionAPI + def streaming_split( + self, + n: int, + *, + equal: bool = False, + locality_hints: Optional[List["NodeIdStr"]] = None, + ) -> List[DatasetIterator]: + """Returns ``n`` :class:`~ray.data.DatasetIterator`s that can be used to read + disjoint subsets of the dataset in parallel. + + This method is the recommended way to consume Datasets from multiple processes + (e.g., for distributed training). It requires streaming execution mode. + + The returned iterators are Ray-serializable and can be freely passed to any + Ray task or actor. + + Examples: + >>> import ray + >>> ds = ray.data.range(1000000) + >>> it1, it2 = ds.streaming_split(2, equal=True) + >>> list(it1.iter_batches()) # doctest: +SKIP + >>> list(it2.iter_batches()) # doctest: +SKIP + + Args: + n: Number of output iterators to return. + equal: If True, each output iterator will see an exactly equal number + of rows, dropping data if necessary. If False, some iterators may see + slightly more or less rows than other, but no data will be dropped. + locality_hints: Specify the node ids corresponding to each iterator + location. Datasets will try to minimize data movement based on the + iterator output locations. This list must have length ``n``. + + Returns: + The output iterator splits. 
+ """ + return StreamSplitDatasetIterator.create(self, n, equal, locality_hints) + @ConsumptionAPI def split( self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None @@ -1165,7 +1204,8 @@ def split( Time complexity: O(1) - See also: ``Dataset.split_at_indices``, ``Dataset.split_proportionately`` + See also: ``Dataset.split_at_indices``, ``Dataset.split_proportionately``, + and ``Dataset.streaming_split``. Args: n: Number of child datasets to return. @@ -1366,7 +1406,8 @@ def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]: Time complexity: O(num splits) - See also: ``Dataset.split``, ``Dataset.split_proportionately`` + See also: ``Dataset.split_at_indices``, ``Dataset.split_proportionately``, + and ``Dataset.streaming_split``. Args: indices: List of sorted integers which indicate where the dataset diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index 104e5d422315..0fb9b8813d51 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -261,6 +261,48 @@ def test_split_operator_random(ray_start_regular_shared, equal, random_seed): assert sum(len(output_splits[i]) for i in range(3)) == num_inputs, output_splits +def test_split_operator_locality_hints(ray_start_regular_shared): + input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(10)])) + op = OutputSplitter(input_op, 2, equal=False, locality_hints=["node1", "node2"]) + + def get_fake_loc(item): + if item in [0, 1, 4, 5, 8]: + return "node1" + else: + return "node2" + + def get_bundle_loc(bundle): + return get_fake_loc(ray.get(bundle.blocks[0][0])[0]) + + op._get_location = get_bundle_loc + + # Feed data and implement streaming exec. + output_splits = collections.defaultdict(list) + op.start(ExecutionOptions()) + while input_op.has_next(): + op.add_input(input_op.get_next(), 0) + op.inputs_done() + while op.has_next(): + ref = op.get_next() + assert ref.owns_blocks, ref + for block, _ in ref.blocks: + output_splits[ref.output_split_idx].extend(ray.get(block)) + + total = 0 + for i in range(2): + if i == 0: + node = "node1" + else: + node = "node2" + split = output_splits[i] + for item in split: + assert get_fake_loc(item) == node + total += 1 + + assert total == 10, total + assert "10 locality hits, 0 misses" in op.progress_str() + + def test_map_operator_actor_locality_stats(ray_start_regular_shared): # Create with inputs. input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(100)])) diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py index e1431856b6cb..cf1f9f6d5eff 100644 --- a/python/ray/data/tests/test_streaming_executor.py +++ b/python/ray/data/tests/test_streaming_executor.py @@ -125,7 +125,7 @@ def test_select_operator_to_run(): o3.num_active_work_refs = MagicMock(return_value=2) o3.internal_queue_size = MagicMock(return_value=0) assert select_operator_to_run(topo, NO_USAGE, ExecutionResources(), True) == o2 - # nternal queue size is added to num active tasks. + # Internal queue size is added to num active tasks. o3.num_active_work_refs = MagicMock(return_value=0) o3.internal_queue_size = MagicMock(return_value=2) assert select_operator_to_run(topo, NO_USAGE, ExecutionResources(), True) == o2 @@ -136,6 +136,10 @@ def test_select_operator_to_run(): o2.internal_queue_size = MagicMock(return_value=2) assert select_operator_to_run(topo, NO_USAGE, ExecutionResources(), True) == o3 + # Test prioritization of nothrottle ops. 
+ o2.throttling_disabled = MagicMock(return_value=True) + assert select_operator_to_run(topo, NO_USAGE, ExecutionResources(), True) == o2 + def test_dispatch_next_task(): inputs = make_ref_bundles([[x] for x in range(20)]) @@ -400,6 +404,31 @@ def test_execution_allowed_downstream_aware_memory_throttling(): ) +def test_execution_allowed_nothrottle(): + op = InputDataBuffer([]) + op.incremental_resource_usage = MagicMock(return_value=ExecutionResources()) + # Above global. + assert not _execution_allowed( + op, + TopologyResourceUsage( + ExecutionResources(object_store_memory=1000), + {op: DownstreamMemoryInfo(1, 1000)}, + ), + ExecutionResources(object_store_memory=900), + ) + + # Throttling disabled. + op.throttling_disabled = MagicMock(return_value=True) + assert _execution_allowed( + op, + TopologyResourceUsage( + ExecutionResources(object_store_memory=1000), + {op: DownstreamMemoryInfo(1, 1000)}, + ), + ExecutionResources(object_store_memory=900), + ) + + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index 0f623309a5f2..69b8ae72c577 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -88,6 +88,51 @@ def run(self): assert len(c1.out) == 10, c0.out +def test_streaming_split_e2e(ray_start_10_cpus_shared): + def get_lengths(*iterators): + lengths = [] + for it in iterators: + x = 0 + for batch in it.iter_batches(): + x += len(batch) + lengths.append(x) + lengths.sort() + return lengths + + ds = ray.data.range(1000) + ( + i1, + i2, + ) = ds.streaming_split(2, equal=True) + lengths = get_lengths(i1, i2) + assert lengths == [500, 500], lengths + + ds = ray.data.range(1) + ( + i1, + i2, + ) = ds.streaming_split(2, equal=True) + lengths = get_lengths(i1, i2) + assert lengths == [0, 0], lengths + + ds = ray.data.range(1) + ( + i1, + i2, + ) = ds.streaming_split(2, equal=False) + lengths = get_lengths(i1, i2) + assert lengths == [0, 1], lengths + + ds = ray.data.range(1000, parallelism=10) + i1, i2, i3 = ds.streaming_split(3, equal=True) + lengths = get_lengths(i1, i2, i3) + assert lengths == [333, 333, 333], lengths + + i1, i2, i3 = ds.streaming_split(3, equal=False) + lengths = get_lengths(i1, i2, i3) + assert lengths == [300, 300, 400], lengths + + def test_e2e_option_propagation(ray_start_10_cpus_shared, restore_dataset_context): DatasetContext.get_current().new_execution_backend = True DatasetContext.get_current().use_streaming_executor = True
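
Usage sketch for reviewers (not part of the diff; assumes a running Ray cluster, and `TrainWorker` is a hypothetical consumer actor used only for illustration). The iterators returned by `streaming_split()` are serializable, so each shard can be drained from its own actor, mirroring `test_streaming_split_e2e`:

import ray

@ray.remote
class TrainWorker:
    # Hypothetical consumer: drains one shard of the shared stream.
    def consume(self, it):
        num_rows = 0
        for batch in it.iter_batches(batch_size=256):
            num_rows += len(batch)
        return num_rows

ds = ray.data.range(1000)
# Each iterator yields a disjoint, equally sized subset of the rows.
it0, it1 = ds.streaming_split(2, equal=True)
workers = [TrainWorker.remote(), TrainWorker.remote()]
# Consume both shards concurrently; the iterators can be freely passed to actors.
print(ray.get([w.consume.remote(it) for w, it in zip(workers, (it0, it1))]))  # [500, 500]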