Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Implement optimizer with Dataset.split() #36363

Merged
merged 4 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/ray/data/_internal/fast_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import ray
from ray.data._internal.block_list import BlockList
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.input_data_operator import InputData
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
Expand All @@ -11,9 +13,17 @@
from ray.data.block import BlockAccessor


def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
def fast_repartition(
blocks: BlockList,
num_blocks: int,
ctx: Optional[TaskContext] = None,
):
from ray.data._internal.execution.legacy_compat import _block_list_to_bundles
from ray.data.dataset import Dataset, Schema

ref_bundles = _block_list_to_bundles(blocks, blocks._owned_by_consumer)
logical_plan = LogicalPlan(InputData(ref_bundles))

wrapped_ds = Dataset(
ExecutionPlan(
blocks,
Expand All @@ -22,6 +32,7 @@ def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
),
0,
lazy=False,
logical_plan=logical_plan,
)
# Compute the (n-1) indices needed for an equal split of the data.
count = wrapped_ds.count()
Expand Down
82 changes: 55 additions & 27 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.equalize import _equalize
from ray.data._internal.execution.interfaces import RefBundle
from ray.data._internal.execution.legacy_compat import _block_list_to_bundles
from ray.data._internal.iterator.iterator_impl import DataIteratorImpl
from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
from ray.data._internal.lazy_block_list import LazyBlockList
Expand Down Expand Up @@ -241,7 +242,7 @@ def __init__(
The constructor is not part of the Dataset API. Use the ``ray.data.*``
read methods to construct a dataset.
"""
assert isinstance(plan, ExecutionPlan)
assert isinstance(plan, ExecutionPlan), type(plan)
usage_lib.record_library_usage("dataset") # Legacy telemetry name.

if ray.util.log_once("strict_mode_explanation"):
Expand Down Expand Up @@ -1228,20 +1229,29 @@ def split(
if locality_hints is None:
blocks = np.array_split(block_refs, n)
meta = np.array_split(metadata, n)
return [
MaterializedDataset(
ExecutionPlan(
BlockList(
b.tolist(), m.tolist(), owned_by_consumer=owned_by_consumer

split_datasets = []
for b, m in zip(blocks, meta):
block_list = BlockList(
b.tolist(), m.tolist(), owned_by_consumer=owned_by_consumer
)
logical_plan = self._plan._logical_plan
if logical_plan is not None:
ref_bundles = _block_list_to_bundles(block_list, owned_by_consumer)
logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
split_datasets.append(
MaterializedDataset(
ExecutionPlan(
block_list,
stats,
run_by_consumer=owned_by_consumer,
),
stats,
run_by_consumer=owned_by_consumer,
),
self._epoch,
self._lazy,
self._epoch,
self._lazy,
logical_plan,
)
)
for b, m in zip(blocks, meta)
]
return split_datasets

metadata_mapping = {b: m for b, m in zip(block_refs, metadata)}

Expand Down Expand Up @@ -1350,18 +1360,25 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]:
# equalize the splits
per_split_block_lists = _equalize(per_split_block_lists, owned_by_consumer)

return [
MaterializedDataset(
ExecutionPlan(
block_split,
stats,
run_by_consumer=owned_by_consumer,
),
self._epoch,
self._lazy,
split_datasets = []
for block_split in per_split_block_lists:
logical_plan = self._plan._logical_plan
if logical_plan is not None:
ref_bundles = _block_list_to_bundles(block_split, owned_by_consumer)
logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
split_datasets.append(
MaterializedDataset(
ExecutionPlan(
block_split,
stats,
run_by_consumer=owned_by_consumer,
),
self._epoch,
self._lazy,
logical_plan,
)
)
for block_split in per_split_block_lists
]
return split_datasets

@ConsumptionAPI
def split_at_indices(self, indices: List[int]) -> List["MaterializedDataset"]:
Expand Down Expand Up @@ -1408,20 +1425,31 @@ def split_at_indices(self, indices: List[int]) -> List["MaterializedDataset"]:
split_duration = time.perf_counter() - start_time
parent_stats = self._plan.stats()
splits = []

for bs, ms in zip(blocks, metadata):
stats = DatasetStats(stages={"Split": ms}, parent=parent_stats)
stats.time_total_s = split_duration

split_block_list = BlockList(
bs, ms, owned_by_consumer=block_list._owned_by_consumer
)
logical_plan = self._plan._logical_plan
if logical_plan is not None:
ref_bundles = _block_list_to_bundles(
split_block_list, block_list._owned_by_consumer
)
logical_plan = LogicalPlan(InputData(input_data=ref_bundles))

splits.append(
MaterializedDataset(
ExecutionPlan(
BlockList(
bs, ms, owned_by_consumer=block_list._owned_by_consumer
),
split_block_list,
stats,
run_by_consumer=block_list._owned_by_consumer,
),
self._epoch,
self._lazy,
logical_plan,
)
)
return splits
Expand Down
14 changes: 12 additions & 2 deletions python/ray/data/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import ray
from ray.data._internal.block_list import BlockList
from ray.data._internal.equalize import _equalize
from ray.data._internal.execution.interfaces import RefBundle
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.input_data_operator import InputData
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.split import (
_drop_empty_block_split,
Expand Down Expand Up @@ -76,24 +79,32 @@ def count(s):
([2, 5], 1), # Single split.
],
)
def test_equal_split_balanced(ray_start_regular_shared, block_sizes, num_splits):
def test_equal_split_balanced(
ray_start_regular_shared, enable_optimizer, block_sizes, num_splits
):
_test_equal_split_balanced(block_sizes, num_splits)


def _test_equal_split_balanced(block_sizes, num_splits):
blocks = []
metadata = []
ref_bundles = []
total_rows = 0
for block_size in block_sizes:
block = pd.DataFrame({"id": list(range(total_rows, total_rows + block_size))})
blocks.append(ray.put(block))
metadata.append(BlockAccessor.for_block(block).get_metadata(None, None))
blk = (blocks[-1], metadata[-1])
ref_bundles.append(RefBundle((blk,), owns_blocks=True))
total_rows += block_size
block_list = BlockList(blocks, metadata, owned_by_consumer=True)

logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
ds = Dataset(
ExecutionPlan(block_list, DatasetStats.TODO(), run_by_consumer=True),
0,
False,
logical_plan,
)

splits = ds.split(num_splits, equal=True)
Expand All @@ -111,7 +122,6 @@ def _test_equal_split_balanced(block_sizes, num_splits):


def test_equal_split_balanced_grid(ray_start_regular_shared):

# Tests balanced equal splitting over a grid of configurations.
# Grid: num_blocks x num_splits x num_rows_block_1 x ... x num_rows_block_n
seed = int(time.time())
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used

Stage N Map: N/N blocks executed in T
Stage N Map(<lambda>): N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
Expand Down