From 303ac3b9b3df14456fd62997118ff42e7cb8c3a8 Mon Sep 17 00:00:00 2001
From: Jian Xiao <99709935+jianoaix@users.noreply.github.com>
Date: Sat, 4 Mar 2023 23:03:27 -0800
Subject: [PATCH] [Datasets] Streaming executor fixes #5 (#32951)

---
 python/ray/data/_internal/lazy_block_list.py | 19 +++++++++----------
 python/ray/data/_internal/util.py            | 19 ++++++++++++++++++-
 python/ray/data/tests/test_dataset.py        | 16 +++++++++++++---
 python/ray/data/tests/test_util.py           | 17 ++++++++++++++++-
 4 files changed, 56 insertions(+), 15 deletions(-)

diff --git a/python/ray/data/_internal/lazy_block_list.py b/python/ray/data/_internal/lazy_block_list.py
index aebb5837298e..f4356aaee563 100644
--- a/python/ray/data/_internal/lazy_block_list.py
+++ b/python/ray/data/_internal/lazy_block_list.py
@@ -2,14 +2,13 @@
 import uuid
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
-import numpy as np
-
 import ray
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data._internal.memory_tracing import trace_allocation
 from ray.data._internal.stats import DatasetStats, _get_or_create_stats_actor
+from ray.data._internal.util import _split_list
 from ray.data.block import (
     Block,
     BlockAccessor,
@@ -162,22 +161,22 @@ def _check_if_cleared(self):
     # Note: does not force execution prior to splitting.
     def split(self, split_size: int) -> List["LazyBlockList"]:
         num_splits = math.ceil(len(self._tasks) / split_size)
-        tasks = np.array_split(self._tasks, num_splits)
-        block_partition_refs = np.array_split(self._block_partition_refs, num_splits)
-        block_partition_meta_refs = np.array_split(
+        tasks = _split_list(self._tasks, num_splits)
+        block_partition_refs = _split_list(self._block_partition_refs, num_splits)
+        block_partition_meta_refs = _split_list(
             self._block_partition_meta_refs, num_splits
         )
-        cached_metadata = np.array_split(self._cached_metadata, num_splits)
+        cached_metadata = _split_list(self._cached_metadata, num_splits)
         output = []
         for t, b, m, c in zip(
             tasks, block_partition_refs, block_partition_meta_refs, cached_metadata
         ):
             output.append(
                 LazyBlockList(
-                    t.tolist(),
-                    b.tolist(),
-                    m.tolist(),
-                    c.tolist(),
+                    t,
+                    b,
+                    m,
+                    c,
                     owned_by_consumer=self._owned_by_consumer,
                 )
             )
diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py
index f681f9735b7e..f947e44fd763 100644
--- a/python/ray/data/_internal/util.py
+++ b/python/ray/data/_internal/util.py
@@ -1,7 +1,7 @@
 import importlib
 import logging
 import os
-from typing import List, Union, Optional, TYPE_CHECKING
+from typing import Any, List, Union, Optional, TYPE_CHECKING
 from types import ModuleType
 import sys
 
@@ -380,3 +380,20 @@ def ConsumptionAPI(*args, **kwargs):
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        return _consumption_api()(args[0])
    return _consumption_api(*args, **kwargs)
+
+
+def _split_list(arr: List[Any], num_splits: int) -> List[List[Any]]:
+    """Split the list into `num_splits` lists.
+
+    The splits are even if `num_splits` evenly divides the length of the list;
+    otherwise the remainder (call it R) is distributed one element each to the
+    first R splits.
+    This matches the behavior of numpy.array_split(). A separate implementation
+    is provided so that lists with heterogeneous element types can be split too.
+    """
+    assert num_splits > 0
+    q, r = divmod(len(arr), num_splits)
+    splits = [
+        arr[i * q + min(i, r) : (i + 1) * q + min(i + 1, r)] for i in range(num_splits)
+    ]
+    return splits
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index fb61d2f520aa..aab3e64ef921 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -4581,9 +4581,19 @@ def test_warning_execute_with_no_cpu(ray_start_cluster):
             ds = ray.data.range(10)
             ds = ds.map_batches(lambda x: x)
             ds.take()
-        except LoggerWarningCalled:
-            logger_args, logger_kwargs = mock_logger.call_args
-            assert "Warning: The Ray cluster currently does not have " in logger_args[0]
+        except Exception as e:
+            if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+                assert isinstance(e, ValueError)
+                assert "exceeds the execution limits ExecutionResources(cpu=0.0" in str(
+                    e
+                )
+            else:
+                assert isinstance(e, LoggerWarningCalled)
+                logger_args, logger_kwargs = mock_logger.call_args
+                assert (
+                    "Warning: The Ray cluster currently does not have "
+                    in logger_args[0]
+                )
 
 
 def test_nowarning_execute_with_cpu(ray_start_cluster_init):
diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py
index 075569226474..f9bd0afbbc5d 100644
--- a/python/ray/data/tests/test_util.py
+++ b/python/ray/data/tests/test_util.py
@@ -2,7 +2,7 @@
 import ray
 import numpy as np
 
-from ray.data._internal.util import _check_pyarrow_version
+from ray.data._internal.util import _check_pyarrow_version, _split_list
 from ray.data._internal.memory_tracing import (
     trace_allocation,
     trace_deallocation,
@@ -72,6 +72,21 @@ def test_memory_tracing(enabled):
     assert "test5" not in report, report
 
 
+def test_list_splits():
+    with pytest.raises(AssertionError):
+        _split_list(list(range(5)), 0)
+
+    with pytest.raises(AssertionError):
+        _split_list(list(range(5)), -1)
+
+    assert _split_list(list(range(5)), 7) == [[0], [1], [2], [3], [4], [], []]
+    assert _split_list(list(range(5)), 2) == [[0, 1, 2], [3, 4]]
+    assert _split_list(list(range(6)), 2) == [[0, 1, 2], [3, 4, 5]]
+    assert _split_list(list(range(5)), 1) == [[0, 1, 2, 3, 4]]
+    assert _split_list(["foo", 1, [0], None], 2) == [["foo", 1], [[0], None]]
+    assert _split_list(["foo", 1, [0], None], 3) == [["foo", 1], [[0]], [None]]
+
+
 if __name__ == "__main__":
     import sys
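
Note (illustrative, not part of the patch): the standalone sketch below reproduces the splitting scheme of the `_split_list` helper added in python/ray/data/_internal/util.py so its semantics can be checked outside the Ray test suite. The NumPy comparison and the variable names are assumptions for illustration only; the expected values mirror the assertions in test_list_splits above.

from typing import Any, List

import numpy as np


def _split_list(arr: List[Any], num_splits: int) -> List[List[Any]]:
    # Same quotient/remainder scheme as the patched helper: when the split is
    # uneven, the first `r` splits each receive one extra element.
    assert num_splits > 0
    q, r = divmod(len(arr), num_splits)
    return [
        arr[i * q + min(i, r) : (i + 1) * q + min(i + 1, r)] for i in range(num_splits)
    ]


# Even split: 6 elements into 2 splits of 3 each.
assert _split_list(list(range(6)), 2) == [[0, 1, 2], [3, 4, 5]]

# Uneven split: the remainder (1 element) goes to the first split.
assert _split_list(list(range(5)), 2) == [[0, 1, 2], [3, 4]]

# For homogeneous inputs the result matches numpy.array_split()...
assert [list(s) for s in np.array_split(list(range(5)), 2)] == _split_list(
    list(range(5)), 2
)

# ...but heterogeneous elements are handled without first being converted into
# a NumPy array, which is what the replaced np.array_split() calls in
# LazyBlockList.split() did to the task and block-ref lists.
assert _split_list(["foo", 1, [0], None], 3) == [["foo", 1], [[0]], [None]]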