[data] cleanup: use SortKey instead of mixed typing in aggregation #48697

Merged · 9 commits · Nov 14, 2024
67 changes: 29 additions & 38 deletions python/ray/data/_internal/arrow_block.py
@@ -10,6 +10,7 @@
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
@@ -502,13 +503,13 @@ def sort_and_partition(

return find_partitions(table, boundaries, sort_key)

def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Block:
def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
"""Combine rows with the same key into an accumulator.

This assumes the block is already sorted by key in ascending order.
Contributor: nit: docstring contains key instead of sort_key. same with other methods

Contributor Author: updated, thanks!

Contributor: SortKey type is kind of a misnomer. It's just the key(s) on which we happen to do things like groupby, sort, join, windowing etc.


Args:
key: A column name or list of column names.
sort_key: A SortKey object which holds column names/keys.
If it holds no column names, place all rows in a single group.

aggs: The aggregations to do.
@@ -519,18 +520,13 @@ def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Block:
aggregation.
If sort_key holds no column names, the key columns are omitted.
"""
if key is not None and not isinstance(key, (str, list)):
raise ValueError(
"key must be a string, list of strings or None when aggregating "
"on Arrow blocks, but "
f"got: {type(key)}."
)
keys: List[str] = sort_key.get_columns()

def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
"""Creates an iterator over zero-copy group views."""
if key is None:
if not keys:
# Global aggregation consists of a single "group", so we short-circuit.
yield None, self.to_block()
yield tuple(), self.to_block()
return

start = end = 0
@@ -540,36 +536,33 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
try:
if next_row is None:
next_row = next(iter)
next_key = next_row[key]
while next_row[key] == next_key:
next_keys = next_row[keys]
while next_row[keys] == next_keys:
end += 1
try:
next_row = next(iter)
except StopIteration:
next_row = None
break
yield next_key, self.slice(start, end)
yield next_keys, self.slice(start, end)
start = end
except StopIteration:
break

builder = ArrowBlockBuilder()
for group_key, group_view in iter_groups():
for group_keys, group_view in iter_groups():
# Aggregate.
accumulators = [agg.init(group_key) for agg in aggs]
init_vals = group_keys
if len(group_keys) == 1:
init_vals = group_keys[0]

accumulators = [agg.init(init_vals) for agg in aggs]
for i in range(len(aggs)):
accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

# Build the row.
row = {}
if key is not None:
if isinstance(key, list):
keys = key
group_keys = group_key
else:
keys = [key]
group_keys = [group_key]

if keys:
for k, gk in zip(keys, group_keys):
row[k] = gk
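One subtlety in the hunk above: with a single key column, agg.init() still receives the bare scalar rather than a 1-tuple, preserving the pre-SortKey calling convention. A standalone sketch of that rule (illustrative helper, not a Ray API):

```python
def seed_value(group_keys: tuple):
    # Mirrors the diff: single-key groups pass the scalar, multi-key the tuple.
    return group_keys[0] if len(group_keys) == 1 else group_keys

assert seed_value(("a",)) == "a"
assert seed_value(("a", 2024)) == ("a", 2024)
assert seed_value(()) == ()  # global aggregation: empty key tuple
```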

@@ -608,7 +601,7 @@ def merge_sorted_blocks(
@staticmethod
def aggregate_combined_blocks(
blocks: List[Block],
key: Union[str, List[str]],
sort_key: "SortKey",
richardliaw marked this conversation as resolved.
aggs: Tuple["AggregateFn"],
finalize: bool,
) -> Tuple[Block, BlockMetadata]:
@@ -619,7 +612,7 @@

Args:
blocks: A list of partially combined and sorted blocks.
key: The column name of key or None for global aggregation.
sort_key: A SortKey object which holds the key column names; empty for global aggregation.
aggs: The aggregations to do.
finalize: Whether to finalize the aggregation. This is used as an
optimization for cases where we repeatedly combine partially
@@ -633,13 +626,13 @@
"""

stats = BlockExecStats.builder()
keys = sort_key.get_columns()

keys = key if isinstance(key, list) else [key]
key_fn = (
(lambda r: tuple(r[r._row.schema.names[: len(keys)]]))
if key is not None
else (lambda r: (0,))
)
def key_fn(r):
if keys:
return tuple(r[keys])
else:
return (0,)

# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")
@@ -658,9 +651,7 @@
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = (
next_row._row.schema.names[: len(keys)] if key is not None else None
)
next_key_columns = keys

def gen():
nonlocal iter
@@ -699,9 +690,9 @@ def gen():
)
# Build the row.
row = {}
if key is not None:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key
if keys:
for col_name, next_key in zip(next_key_columns, next_keys):
row[col_name] = next_key

for agg, agg_name, accumulator in zip(
aggs, resolved_agg_names, accumulators
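The heart of the rewritten combine() is a run-detection scan over a block already sorted by the key columns. A plain-Python sketch of that loop, with illustrative names rather than Ray APIs:

```python
from typing import Dict, Iterator, List, Tuple

def iter_sorted_runs(
    rows: List[Dict], keys: List[str]
) -> Iterator[Tuple[tuple, Tuple[int, int]]]:
    """Yield (key_tuple, [start, end)) bounds for each run of equal keys."""
    if not keys:
        # Global aggregation: the whole block is a single "group".
        yield (), (0, len(rows))
        return
    start = 0
    while start < len(rows):
        run_key = tuple(rows[start][k] for k in keys)
        end = start
        while end < len(rows) and tuple(rows[end][k] for k in keys) == run_key:
            end += 1
        yield run_key, (start, end)
        start = end

rows = [{"k": "a", "v": 1}, {"k": "a", "v": 2}, {"k": "b", "v": 3}]
assert list(iter_sorted_runs(rows, ["k"])) == [(("a",), (0, 2)), (("b",), (2, 3))]
```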
70 changes: 32 additions & 38 deletions python/ray/data/_internal/pandas_block.py
@@ -8,6 +8,7 @@
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
@@ -415,14 +416,14 @@ def sort_and_partition(
return find_partitions(table, boundaries, sort_key)

def combine(
self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]
self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]
) -> "pandas.DataFrame":
"""Combine rows with the same key into an accumulator.

This assumes the block is already sorted by key in ascending order.

Args:
key: A column name or list of column names.
sort_key: A SortKey object which holds column names/keys.
If it holds no column names, place all rows in a single group.

aggs: The aggregations to do.
@@ -433,18 +434,14 @@ def combine(
aggregation.
If sort_key holds no column names, the key columns are omitted.
"""
if key is not None and not isinstance(key, (str, list)):
raise ValueError(
"key must be a string, list of strings or None when aggregating "
"on Pandas blocks, but "
f"got: {type(key)}."
)
keys: List[str] = sort_key.get_columns()
pd = lazy_import_pandas()

def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
"""Creates an iterator over zero-copy group views."""
if key is None:
if not keys:
# Global aggregation consists of a single "group", so we short-circuit.
yield None, self.to_block()
yield tuple(), self.to_block()
return

start = end = 0
@@ -454,36 +451,34 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
try:
if next_row is None:
next_row = next(iter)
next_key = next_row[key]
while np.all(next_row[key] == next_key):
next_keys = next_row[keys]
while np.all(next_row[keys] == next_keys):
end += 1
try:
next_row = next(iter)
except StopIteration:
next_row = None
break
yield next_key, self.slice(start, end, copy=False)
if isinstance(next_keys, pd.Series):
next_keys = next_keys.values
yield next_keys, self.slice(start, end, copy=False)
start = end
except StopIteration:
break

builder = PandasBlockBuilder()
for group_key, group_view in iter_groups():
for group_keys, group_view in iter_groups():
# Aggregate.
accumulators = [agg.init(group_key) for agg in aggs]
init_vals = group_keys
if len(group_keys) == 1:
init_vals = group_keys[0]
accumulators = [agg.init(init_vals) for agg in aggs]
for i in range(len(aggs)):
accumulators[i] = aggs[i].accumulate_block(accumulators[i], group_view)

# Build the row.
row = {}
if key is not None:
if isinstance(key, list):
keys = key
group_keys = group_key
else:
keys = [key]
group_keys = [group_key]

if keys:
for k, gk in zip(keys, group_keys):
row[k] = gk

@@ -520,7 +515,7 @@ def merge_sorted_blocks(
@staticmethod
def aggregate_combined_blocks(
blocks: List["pandas.DataFrame"],
key: Union[str, List[str]],
sort_key: "SortKey",
aggs: Tuple["AggregateFn"],
finalize: bool,
) -> Tuple["pandas.DataFrame", BlockMetadata]:
@@ -531,7 +526,7 @@ def aggregate_combined_blocks(

Args:
blocks: A list of partially combined and sorted blocks.
key: The column name of key or None for global aggregation.
sort_key: A SortKey object which holds the key column names; empty for global aggregation.
aggs: The aggregations to do.
finalize: Whether to finalize the aggregation. This is used as an
optimization for cases where we repeatedly combine partially
@@ -545,12 +540,13 @@
"""

stats = BlockExecStats.builder()
keys = key if isinstance(key, list) else [key]
key_fn = (
(lambda r: tuple(r[r._row.columns[: len(keys)]]))
if key is not None
else (lambda r: (0,))
)
keys = sort_key.get_columns()

def key_fn(r):
if keys:
return tuple(r[keys])
else:
return (0,)

# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas")
@@ -569,9 +565,7 @@
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = (
next_row._row.columns[: len(keys)] if key is not None else None
)
next_key_columns = keys

def gen():
nonlocal iter
@@ -610,9 +604,9 @@ def gen():
)
# Build the row.
row = {}
if key is not None:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key
if keys:
for col_name, next_key in zip(next_key_columns, next_keys):
row[col_name] = next_key

for agg, agg_name, accumulator in zip(
aggs, resolved_agg_names, accumulators
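aggregate_combined_blocks() does the analogous fold one level up: each partially combined block is already sorted by key, so equal keys can be merged across blocks with a heap-based k-way merge driven by the same key_fn. A plain-Python stand-in, not the Ray implementation:

```python
import heapq
from typing import Dict, List

def key_fn(row: Dict, keys: List[str]) -> tuple:
    # Mirrors the diff: tuple of the key columns, or a constant for global agg.
    return tuple(row[k] for k in keys) if keys else (0,)

blocks = [
    [{"k": "a", "sum": 1}, {"k": "c", "sum": 5}],
    [{"k": "a", "sum": 2}, {"k": "b", "sum": 3}],
]
merged = heapq.merge(*blocks, key=lambda r: key_fn(r, ["k"]))

# Consecutive rows with equal keys in the merged stream fold into one output row.
out: List[Dict] = []
for row in merged:
    if out and key_fn(out[-1], ["k"]) == key_fn(row, ["k"]):
        out[-1]["sum"] += row["sum"]
    else:
        out.append(dict(row))

assert out == [{"k": "a", "sum": 3}, {"k": "b", "sum": 3}, {"k": "c", "sum": 5}]
```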
10 changes: 6 additions & 4 deletions python/ray/data/_internal/planner/aggregate.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

from ray.data._internal.execution.interfaces import (
AllToAllTransformFn,
@@ -22,7 +22,7 @@


def generate_aggregate_fn(
key: Optional[str],
key: Optional[Union[str, List[str]]],
aggs: List[AggregateFn],
Contributor: Why not make this API accept SortKey as well?

Contributor Author: Technically there is no need for the aggregate function to take a SortKey; we just happen to use it as an implementation detail (our aggregations are sort-based).

_debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
) -> AllToAllTransformFn:
@@ -49,6 +49,8 @@ def fn(

num_mappers = len(blocks)

sort_key = SortKey(key)

if key is None:
num_outputs = 1
boundaries = []
@@ -60,12 +62,12 @@
]
# Sample boundaries for aggregate key.
boundaries = SortTaskSpec.sample_boundaries(
blocks, SortKey(key), num_outputs, sample_bar
blocks, sort_key, num_outputs, sample_bar
)

agg_spec = SortAggregateTaskSpec(
boundaries=boundaries,
key=key,
key=sort_key,
aggs=aggs,
)
if DataContext.get_current().use_push_based_shuffle:
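For context, the user-facing path that exercises generate_aggregate_fn() is Dataset.groupby(...).sum(...) and friends; a small hedged example (the sum(v) output column name follows Ray's usual convention and may vary by version):

```python
import ray

ds = ray.data.from_items(
    [{"k": "a", "v": 1}, {"k": "a", "v": 2}, {"k": "b", "v": 3}]
)
# Sort-based groupby aggregation flows through the planner code above.
print(ds.groupby("k").sum("v").take_all())
# Expected shape (column naming may differ by Ray version):
# [{'k': 'a', 'sum(v)': 3}, {'k': 'b', 'sum(v)': 3}]
```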