diff --git a/python/ray/data/_internal/fast_repartition.py b/python/ray/data/_internal/fast_repartition.py
index 99878524f7e7..5c31da6311b2 100644
--- a/python/ray/data/_internal/fast_repartition.py
+++ b/python/ray/data/_internal/fast_repartition.py
@@ -3,6 +3,8 @@
 import ray
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.execution.interfaces import TaskContext
+from ray.data._internal.logical.interfaces import LogicalPlan
+from ray.data._internal.logical.operators.input_data_operator import InputData
 from ray.data._internal.plan import ExecutionPlan
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
@@ -11,9 +13,17 @@
 from ray.data.block import BlockAccessor


-def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
+def fast_repartition(
+    blocks: BlockList,
+    num_blocks: int,
+    ctx: Optional[TaskContext] = None,
+):
+    from ray.data._internal.execution.legacy_compat import _block_list_to_bundles
     from ray.data.dataset import Dataset, Schema

+    ref_bundles = _block_list_to_bundles(blocks, blocks._owned_by_consumer)
+    logical_plan = LogicalPlan(InputData(ref_bundles))
+
     wrapped_ds = Dataset(
         ExecutionPlan(
             blocks,
@@ -22,6 +32,7 @@ def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
         ),
         0,
         lazy=False,
+        logical_plan=logical_plan,
     )
     # Compute the (n-1) indices needed for an equal split of the data.
     count = wrapped_ds.count()
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index b441c8bd7c84..7e1d64226459 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -40,6 +40,7 @@
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
 from ray.data._internal.equalize import _equalize
 from ray.data._internal.execution.interfaces import RefBundle
+from ray.data._internal.execution.legacy_compat import _block_list_to_bundles
 from ray.data._internal.iterator.iterator_impl import DataIteratorImpl
 from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
 from ray.data._internal.lazy_block_list import LazyBlockList
@@ -241,7 +242,7 @@ def __init__(
         The constructor is not part of the Dataset API. Use the ``ray.data.*``
         read methods to construct a dataset.
         """
-        assert isinstance(plan, ExecutionPlan)
+        assert isinstance(plan, ExecutionPlan), type(plan)
         usage_lib.record_library_usage("dataset")  # Legacy telemetry name.

         if ray.util.log_once("strict_mode_explanation"):
@@ -1228,20 +1229,29 @@ def split(
         if locality_hints is None:
             blocks = np.array_split(block_refs, n)
             meta = np.array_split(metadata, n)
-            return [
-                MaterializedDataset(
-                    ExecutionPlan(
-                        BlockList(
-                            b.tolist(), m.tolist(), owned_by_consumer=owned_by_consumer
+
+            split_datasets = []
+            for b, m in zip(blocks, meta):
+                block_list = BlockList(
+                    b.tolist(), m.tolist(), owned_by_consumer=owned_by_consumer
+                )
+                logical_plan = self._plan._logical_plan
+                if logical_plan is not None:
+                    ref_bundles = _block_list_to_bundles(block_list, owned_by_consumer)
+                    logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
+                split_datasets.append(
+                    MaterializedDataset(
+                        ExecutionPlan(
+                            block_list,
+                            stats,
+                            run_by_consumer=owned_by_consumer,
                         ),
-                        stats,
-                        run_by_consumer=owned_by_consumer,
-                    ),
-                    self._epoch,
-                    self._lazy,
+                        self._epoch,
+                        self._lazy,
+                        logical_plan,
+                    )
                 )
-                for b, m in zip(blocks, meta)
-            ]
+            return split_datasets

         metadata_mapping = {b: m for b, m in zip(block_refs, metadata)}

@@ -1350,18 +1360,25 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]:
         # equalize the splits
         per_split_block_lists = _equalize(per_split_block_lists, owned_by_consumer)

-        return [
-            MaterializedDataset(
-                ExecutionPlan(
-                    block_split,
-                    stats,
-                    run_by_consumer=owned_by_consumer,
-                ),
-                self._epoch,
-                self._lazy,
+        split_datasets = []
+        for block_split in per_split_block_lists:
+            logical_plan = self._plan._logical_plan
+            if logical_plan is not None:
+                ref_bundles = _block_list_to_bundles(block_split, owned_by_consumer)
+                logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
+            split_datasets.append(
+                MaterializedDataset(
+                    ExecutionPlan(
+                        block_split,
+                        stats,
+                        run_by_consumer=owned_by_consumer,
+                    ),
+                    self._epoch,
+                    self._lazy,
+                    logical_plan,
+                )
             )
-            for block_split in per_split_block_lists
-        ]
+        return split_datasets

     @ConsumptionAPI
     def split_at_indices(self, indices: List[int]) -> List["MaterializedDataset"]:
@@ -1408,20 +1425,31 @@ def split_at_indices(self, indices: List[int]) -> List["MaterializedDataset"]:
         split_duration = time.perf_counter() - start_time
         parent_stats = self._plan.stats()
         splits = []
+
         for bs, ms in zip(blocks, metadata):
             stats = DatasetStats(stages={"Split": ms}, parent=parent_stats)
             stats.time_total_s = split_duration
+
+            split_block_list = BlockList(
+                bs, ms, owned_by_consumer=block_list._owned_by_consumer
+            )
+            logical_plan = self._plan._logical_plan
+            if logical_plan is not None:
+                ref_bundles = _block_list_to_bundles(
+                    split_block_list, block_list._owned_by_consumer
+                )
+                logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
+
             splits.append(
                 MaterializedDataset(
                     ExecutionPlan(
-                        BlockList(
-                            bs, ms, owned_by_consumer=block_list._owned_by_consumer
-                        ),
+                        split_block_list,
                         stats,
                         run_by_consumer=block_list._owned_by_consumer,
                     ),
                     self._epoch,
                     self._lazy,
+                    logical_plan,
                 )
             )
         return splits
diff --git a/python/ray/data/tests/test_split.py b/python/ray/data/tests/test_split.py
index 257935c2c7a2..fa9895756d03 100644
--- a/python/ray/data/tests/test_split.py
+++ b/python/ray/data/tests/test_split.py
@@ -11,6 +11,9 @@
 import ray
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.equalize import _equalize
+from ray.data._internal.execution.interfaces import RefBundle
+from ray.data._internal.logical.interfaces import LogicalPlan
+from ray.data._internal.logical.operators.input_data_operator import InputData
 from ray.data._internal.plan import ExecutionPlan
 from ray.data._internal.split import (
     _drop_empty_block_split,
@@ -76,24 +79,32 @@ def count(s):
         ([2, 5], 1),  # Single split.
     ],
 )
-def test_equal_split_balanced(ray_start_regular_shared, block_sizes, num_splits):
+def test_equal_split_balanced(
+    ray_start_regular_shared, enable_optimizer, block_sizes, num_splits
+):
     _test_equal_split_balanced(block_sizes, num_splits)


 def _test_equal_split_balanced(block_sizes, num_splits):
     blocks = []
     metadata = []
+    ref_bundles = []
     total_rows = 0
     for block_size in block_sizes:
         block = pd.DataFrame({"id": list(range(total_rows, total_rows + block_size))})
         blocks.append(ray.put(block))
         metadata.append(BlockAccessor.for_block(block).get_metadata(None, None))
+        blk = (blocks[-1], metadata[-1])
+        ref_bundles.append(RefBundle((blk,), owns_blocks=True))
         total_rows += block_size
     block_list = BlockList(blocks, metadata, owned_by_consumer=True)
+
+    logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
     ds = Dataset(
         ExecutionPlan(block_list, DatasetStats.TODO(), run_by_consumer=True),
         0,
         False,
+        logical_plan,
     )

     splits = ds.split(num_splits, equal=True)
@@ -111,7 +122,6 @@ def _test_equal_split_balanced(block_sizes, num_splits):


 def test_equal_split_balanced_grid(ray_start_regular_shared):
-
     # Tests balanced equal splitting over a grid of configurations.
     # Grid: num_blocks x num_splits x num_rows_block_1 x ... x num_rows_block_n
     seed = int(time.time())
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py
index 5c7e26d0a2cf..9b35928029b3 100644
--- a/python/ray/data/tests/test_stats.py
+++ b/python/ray/data/tests/test_stats.py
@@ -575,7 +575,7 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
 * Output size bytes: N min, N max, N mean, N total
 * Tasks per node: N min, N max, N mean; N nodes used

-Stage N Map: N/N blocks executed in T
+Stage N Map(): N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
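Every hunk above repeats one pattern: wrap each split's materialized BlockList as RefBundles and root a fresh LogicalPlan at an InputData operator, then pass that plan to the resulting MaterializedDataset. A minimal sketch of that pattern, for reference only; it uses the internal helpers imported by this patch, and the helper name _logical_plan_for_split is hypothetical rather than something the patch adds:

    # Sketch only; assumes the internal Ray Data helpers imported by this patch.
    # _logical_plan_for_split is a hypothetical name, not part of the change.
    from ray.data._internal.execution.legacy_compat import _block_list_to_bundles
    from ray.data._internal.logical.interfaces import LogicalPlan
    from ray.data._internal.logical.operators.input_data_operator import InputData


    def _logical_plan_for_split(block_list, owned_by_consumer):
        # Convert the split's already-materialized blocks into RefBundles and
        # seed a new LogicalPlan rooted at InputData, so the split dataset can
        # still be consumed when the logical optimizer is enabled.
        ref_bundles = _block_list_to_bundles(block_list, owned_by_consumer)
        return LogicalPlan(InputData(input_data=ref_bundles))

Note that the dataset.py hunks guard this with `if logical_plan is not None`, so datasets without a logical plan keep passing None through unchanged.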