[datasets] Unify Datasets primitives on a common shuffle op (#23614)
Currently, the Datasets primitives repartition, groupby, sort, and random_shuffle all use different internal shuffle implementations. This PR unifies them on a single internal ShuffleOp class, which exposes static map and reduce methods that each higher-level primitive must implement. The ShuffleOp.execute method then implements a simple pull-based shuffle by submitting one map task per input block and one reduce task per output block.

Closes #23593.
stephanie-wang authored Apr 5, 2022
1 parent dc994db commit 9813f2c
Showing 7 changed files with 332 additions and 271 deletions.
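To make the design concrete, here is a minimal illustrative sketch of the pattern the commit message describes: a pluggable static map/reduce pair driven by a generic pull-based execute. The class and method names mirror the diffs below, but the bodies are simplified (no block metadata, stats, progress bars, or clear_input_blocks handling) — this is a sketch, not the exact Ray source.

```python
from typing import Any, List

import ray


class ShuffleOp:
    """Subclasses plug primitive-specific map/reduce hooks into a
    generic pull-based shuffle (simplified sketch)."""

    def __init__(self, map_args: List[Any] = None, reduce_args: List[Any] = None):
        self._map_args = map_args or []
        self._reduce_args = reduce_args or []

    @staticmethod
    def map(idx: int, block: Any, output_num_blocks: int, *map_args) -> List[Any]:
        """Split one input block into output_num_blocks partitions."""
        raise NotImplementedError

    @staticmethod
    def reduce(*mapper_outputs) -> Any:
        """Merge one partition from every map task into one output block."""
        raise NotImplementedError

    def execute(self, blocks: List[Any], output_num_blocks: int) -> List[Any]:
        # For simplicity this sketch assumes output_num_blocks > 1, so each
        # map task returns a list of object refs, one per output block.
        shuffle_map = ray.remote(num_returns=output_num_blocks)(type(self).map)
        shuffle_reduce = ray.remote(type(self).reduce)

        # One map task per input block.
        map_out = [
            shuffle_map.remote(i, block, output_num_blocks, *self._map_args)
            for i, block in enumerate(blocks)
        ]
        # One reduce task per output block, pulling its partition from
        # every map task.
        return [
            shuffle_reduce.remote(*self._reduce_args, *[row[j] for row in map_out])
            for j in range(output_num_blocks)
        ]
```

Under this split, a higher-level primitive only has to define how one block is partitioned (map) and how matching partitions are merged (reduce), as GroupbyOp does in grouped_dataset.py below.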
19 changes: 10 additions & 9 deletions python/ray/data/dataset.py
@@ -73,7 +73,7 @@
 from ray.data.impl.compute import cache_wrapper, CallableClass, ComputeStrategy
 from ray.data.impl.output_buffer import BlockOutputBuffer
 from ray.data.impl.progress_bar import ProgressBar
-from ray.data.impl.shuffle import simple_shuffle
+from ray.data.impl.shuffle import ShufflePartitionOp
 from ray.data.impl.fast_repartition import fast_repartition
 from ray.data.impl.sort import sort_impl
 from ray.data.impl.block_list import BlockList
@@ -497,10 +497,11 @@ def do_shuffle(
                 block_list.clear()
             else:
                 blocks = block_list
-            return simple_shuffle(
+            shuffle_op = ShufflePartitionOp(block_udf, random_shuffle=False)
+            return shuffle_op.execute(
                 blocks,
-                block_udf,
                 num_blocks,
+                clear_input_blocks,
                 map_ray_remote_args=remote_args,
                 reduce_ray_remote_args=remote_args,
             )
@@ -566,16 +567,16 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
                 block_list.clear()
             else:
                 blocks = block_list
-            new_blocks, stage_info = simple_shuffle(
+            random_shuffle_op = ShufflePartitionOp(
+                block_udf, random_shuffle=True, random_seed=seed
+            )
+            return random_shuffle_op.execute(
                 blocks,
-                block_udf,
                 num_blocks,
-                random_shuffle=True,
-                random_seed=seed,
+                clear_input_blocks,
                 map_ray_remote_args=remote_args,
                 reduce_ray_remote_args=remote_args,
             )
-            return new_blocks, stage_info
 
         plan = self._plan.with_stage(
             AllToAllStage(
@@ -1449,7 +1450,7 @@ def do_sort(block_list, clear_input_blocks: bool, *_):
                     _validate_key_fn(self, subkey)
             else:
                 _validate_key_fn(self, key)
-            return sort_impl(blocks, key, descending)
+            return sort_impl(blocks, clear_input_blocks, key, descending)
 
         plan = self._plan.with_stage(AllToAllStage("sort", None, do_sort))
         return Dataset(plan, self._epoch, self._lazy)
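With this change, one op serves both primitives: repartition constructs ShufflePartitionOp(block_udf, random_shuffle=False), while random_shuffle passes random_shuffle=True plus a seed, and execute now takes the block list, the output block count, the clear_input_blocks flag, and per-phase remote-arg overrides in place of the old keyword-heavy simple_shuffle call.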
113 changes: 44 additions & 69 deletions python/ray/data/grouped_dataset.py
@@ -1,20 +1,52 @@
 from typing import Any, Union, Generic, Tuple, List, Callable
-import numpy as np
-import ray
 from ray.util.annotations import PublicAPI
 from ray.data.dataset import Dataset
 from ray.data.dataset import BatchType
 from ray.data.impl import sort
 from ray.data.aggregate import AggregateFn, Count, Sum, Max, Min, Mean, Std
 from ray.data.block import BlockExecStats, KeyFn
 from ray.data.impl.plan import AllToAllStage
-from ray.data.impl.block_list import BlockList
 from ray.data.impl.compute import CallableClass, ComputeStrategy
-from ray.data.impl.remote_fn import cached_remote_fn
-from ray.data.impl.progress_bar import ProgressBar
+from ray.data.impl.shuffle import ShuffleOp
 from ray.data.block import Block, BlockAccessor, BlockMetadata, T, U, KeyType
 
+
+class GroupbyOp(ShuffleOp):
+    @staticmethod
+    def map(
+        idx: int,
+        block: Block,
+        output_num_blocks: int,
+        boundaries: List[KeyType],
+        key: KeyFn,
+        aggs: Tuple[AggregateFn],
+    ) -> List[Union[BlockMetadata, Block]]:
+        """Partition the block and combine rows with the same key."""
+        stats = BlockExecStats.builder()
+        if key is None:
+            partitions = [block]
+        else:
+            partitions = BlockAccessor.for_block(block).sort_and_partition(
+                boundaries,
+                [(key, "ascending")] if isinstance(key, str) else key,
+                descending=False,
+            )
+        parts = [BlockAccessor.for_block(p).combine(key, aggs) for p in partitions]
+        meta = BlockAccessor.for_block(block).get_metadata(
+            input_files=None, exec_stats=stats.build()
+        )
+        return [meta] + parts
+
+    @staticmethod
+    def reduce(
+        key: KeyFn, aggs: Tuple[AggregateFn], *mapper_outputs: List[Block]
+    ) -> (Block, BlockMetadata):
+        """Aggregate sorted and partially combined blocks."""
+        return BlockAccessor.for_block(mapper_outputs[0]).aggregate_combined_blocks(
+            list(mapper_outputs), key, aggs
+        )
+
+
 @PublicAPI
 class GroupedDataset(Generic[T]):
     """Represents a grouped dataset created by calling ``Dataset.groupby()``.
@@ -86,42 +118,14 @@ def do_agg(blocks, clear_input_blocks: bool, *_):
                     else self._key,
                     num_reducers,
                 )
-
-            partition_and_combine_block = cached_remote_fn(
-                _partition_and_combine_block
-            ).options(num_returns=num_reducers + 1)
-            aggregate_combined_blocks = cached_remote_fn(
-                _aggregate_combined_blocks, num_returns=2
+            shuffle_op = GroupbyOp(
+                map_args=[boundaries, self._key, aggs], reduce_args=[self._key, aggs]
             )
+            return shuffle_op.execute(
+                blocks,
+                num_reducers,
+                clear_input_blocks,
+            )
-
-            map_results = np.empty((num_mappers, num_reducers), dtype=object)
-            map_meta = []
-            for i, block in enumerate(blocks.get_blocks()):
-                results = partition_and_combine_block.remote(
-                    block, boundaries, self._key, aggs
-                )
-                map_results[i, :] = results[:-1]
-                map_meta.append(results[-1])
-            map_bar = ProgressBar("GroupBy Map", len(map_results))
-            map_bar.block_until_complete(map_meta)
-            stage_info["map"] = ray.get(map_meta)
-            map_bar.close()
-
-            blocks = []
-            metadata = []
-            for j in range(num_reducers):
-                block, meta = aggregate_combined_blocks.remote(
-                    num_reducers, self._key, aggs, *map_results[:, j].tolist()
-                )
-                blocks.append(block)
-                metadata.append(meta)
-            reduce_bar = ProgressBar("GroupBy Reduce", len(blocks))
-            reduce_bar.block_until_complete(blocks)
-            reduce_bar.close()
-
-            metadata = ray.get(metadata)
-            stage_info["reduce"] = metadata
-            return BlockList(blocks, metadata), stage_info
 
         plan = self._dataset._plan.with_stage(AllToAllStage("aggregate", None, do_agg))
         return Dataset(
@@ -603,32 +607,3 @@ def std(
         If groupby key is ``None`` then the key part of return is omitted.
         """
         return self._aggregate_on(Std, on, ignore_nulls, ddof=ddof)
-
-
-def _partition_and_combine_block(
-    block: Block[T], boundaries: List[KeyType], key: KeyFn, aggs: Tuple[AggregateFn]
-) -> List[Union[Block, BlockMetadata]]:
-    """Partition the block and combine rows with the same key."""
-    stats = BlockExecStats.builder()
-    if key is None:
-        partitions = [block]
-    else:
-        partitions = BlockAccessor.for_block(block).sort_and_partition(
-            boundaries,
-            [(key, "ascending")] if isinstance(key, str) else key,
-            descending=False,
-        )
-    parts = [BlockAccessor.for_block(p).combine(key, aggs) for p in partitions]
-    meta = BlockAccessor.for_block(block).get_metadata(
-        input_files=None, exec_stats=stats.build()
-    )
-    return parts + [meta]
-
-
-def _aggregate_combined_blocks(
-    num_reducers: int, key: KeyFn, aggs: Tuple[AggregateFn], *blocks: Tuple[Block, ...]
-) -> Tuple[Block[U], BlockMetadata]:
-    """Aggregate sorted and partially combined blocks."""
-    return BlockAccessor.for_block(blocks[0]).aggregate_combined_blocks(
-        list(blocks), key, aggs
-    )
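One subtle convention change in this file: GroupbyOp.map returns the metadata first ([meta] + parts), whereas the removed _partition_and_combine_block returned parts + [meta]. Presumably the shared ShuffleOp.execute expects each map task's metadata at a fixed position so that it can drive the progress bars and stage stats that were deleted from do_agg above.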
6 changes: 3 additions & 3 deletions python/ray/data/impl/fast_repartition.py
@@ -5,7 +5,7 @@
 from ray.data.impl.plan import ExecutionPlan
 from ray.data.impl.progress_bar import ProgressBar
 from ray.data.impl.remote_fn import cached_remote_fn
-from ray.data.impl.shuffle import _shuffle_reduce
+from ray.data.impl.shuffle import ShufflePartitionOp
 from ray.data.impl.stats import DatasetStats
 
 
@@ -32,10 +32,10 @@ def fast_repartition(blocks, num_blocks):
     # consider combining the split and coalesce tasks as an optimization.
 
     # Coalesce each split into a single block.
-    reduce_task = cached_remote_fn(_shuffle_reduce).options(num_returns=2)
+    reduce_task = cached_remote_fn(ShufflePartitionOp.reduce).options(num_returns=2)
     reduce_bar = ProgressBar("Repartition", position=0, total=len(splits))
     reduce_out = [
-        reduce_task.remote(*s.get_internal_block_refs())
+        reduce_task.remote(False, None, *s.get_internal_block_refs())
         for s in splits
         if s.num_blocks() > 0
     ]
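The two new leading positional arguments here (False, None) fill the static reduce arguments of ShufflePartitionOp.reduce — presumably a random-shuffle flag and seed, as suggested by the ShufflePartitionOp(block_udf, random_shuffle=True, random_seed=seed) construction in dataset.py above; neither applies when merely coalescing each split into a single block.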