Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Optimizing the multi column groupby #45667

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,11 @@ py_test(
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

# Small test target that runs tests/test_group_boundaries.py.
py_test(
name = "test_group_boundaries",
size = "small",
srcs = ["tests/test_group_boundaries.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)
29 changes: 29 additions & 0 deletions python/ray/data/_internal/boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Dict, Union

import numpy as np


def get_key_boundaries(
    keys: Union[np.ndarray, Dict[str, np.ndarray]], include_first: bool = True
) -> np.ndarray:
    """Compute group boundaries for rows whose equal keys are adjacent.

    A new group begins at every row whose key differs from the key of the
    row right before it. With multiple key columns, a row starts a new group
    as soon as *any* of its key columns differs from the previous row.

    Args:
        keys: Either a single numpy array of key values, or a dict mapping
            column name to a numpy array for multi-column keys. All arrays
            of a dict are expected to have the same length.
        include_first: Whether to include the leading index (0).

    Returns:
        An integer array with the exclusive end index of every group, i.e.
        the start index of each group after the first one followed by the
        number of rows. If ``include_first`` is True the array is prefixed
        with 0, so consecutive pairs ``(b[i], b[i + 1])`` are the
        ``[start, end)`` ranges of the groups. For empty input there are no
        groups: the result is ``[0]`` with ``include_first`` and ``[]``
        without it.
    """
    columns = list(keys.values()) if isinstance(keys, dict) else [keys]
    num_rows = len(columns[0]) if columns else 0

    # Optional leading 0; kept as an integer array so that concatenation
    # never silently upcasts the result to float.
    head = np.zeros(1 if include_first else 0, dtype=np.intp)
    if num_rows == 0:
        # No rows means no groups, hence no end index to report.
        return head

    # changed[i] is True when row i + 1 has a different key than row i.
    # Comparing column by column keeps the work vectorized instead of
    # materializing one Python tuple per row.
    changed = np.zeros(num_rows - 1, dtype=bool)
    for column in columns:
        changed |= np.asarray(column[1:] != column[:-1], dtype=bool)

    starts = np.flatnonzero(changed) + 1
    tail = np.array([num_rows], dtype=np.intp)
    return np.concatenate([head, starts, tail])
4 changes: 2 additions & 2 deletions python/ray/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def slice(self, start: int, end: int, copy: bool) -> Block:
"""Return a slice of this block.

Args:
start: The starting index of the slice.
end: The ending index of the slice.
start: The starting index of the slice (inclusive).
end: The ending index of the slice (exclusive).
copy: Whether to perform a data copy for the slice.

Returns:
Expand Down
61 changes: 11 additions & 50 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from ray.data._internal.aggregate import Count, Max, Mean, Min, Std, Sum
from ray.data._internal.boundaries import get_key_boundaries
from ray.data._internal.compute import ComputeStrategy
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.all_to_all_operator import Aggregate
Expand All @@ -13,28 +14,6 @@
FA_API_GROUP = "Function Application"


class _MultiColumnSortedKey:
    """Orderable wrapper around the group-key values of a single row.

    Multi-column groupby needs one comparable object per row so that
    ``np.searchsorted`` can locate group key boundaries. The wrapped values
    are kept as a tuple, and ordering is the tuple's lexicographical order.
    Only ``__lt__`` is provided.
    """

    __slots__ = ("data",)

    def __init__(self, *args):
        # Positional arguments already arrive packed as a tuple.
        self.data = args

    def __lt__(self, obj: "_MultiColumnSortedKey") -> bool:
        mine, theirs = self.data, obj.data
        return mine < theirs

    def __repr__(self) -> str:
        """Render as ``T(1, 2)``."""
        return f"T{self.data!r}"

class GroupedData:
"""Represents a grouped dataset created by calling ``Dataset.groupby()``.

Expand All @@ -44,7 +23,7 @@ class GroupedData:
def __init__(
self,
dataset: Dataset,
key: Union[str, List[str]],
key: Optional[Union[str, List[str]]],
):
"""Construct a dataset grouped by key (internal API).

Expand Down Expand Up @@ -197,47 +176,29 @@ def map_groups(
else:
sorted_ds = self._dataset.repartition(1)

def get_key_boundaries(block_accessor: BlockAccessor) -> List[int]:
    """Return the end index of every run of equal keys in the block.

    NOTE(review): ``np.searchsorted`` is used to find where each run ends,
    so the key values are presumably already sorted when this is called.
    """

    import numpy as np

    # Key column(s) of the block as numpy data: a dict when there are
    # several key columns, a plain array otherwise.
    key_values = block_accessor.to_numpy(self._key)

    if isinstance(key_values, dict):
        # Fold the per-column arrays into one array of comparable row keys.
        to_row_keys = np.vectorize(_MultiColumnSortedKey)
        key_values = to_row_keys(*key_values.values())

    group_ends = []
    cursor = 0
    total = key_values.size
    while cursor < total:
        # Number of rows, counted from ``cursor``, that are not greater
        # than the key at ``cursor``.
        run_length = np.searchsorted(
            key_values[cursor:], key_values[cursor], side="right"
        )
        cursor = cursor + run_length
        group_ends.append(cursor)
    return group_ends

# The batch is the entire block, because we have batch_size=None for
# map_batches() below.
def apply_udf_to_groups(udf, batch, *args, **kwargs):
block = BlockAccessor.batch_to_block(batch)
block_accessor = BlockAccessor.for_block(block)

# Get the list of boundaries including first start and last end indices
if self._key:
boundaries = get_key_boundaries(block_accessor)
# Get the keys of the batch in numpy array format
keys = block_accessor.to_numpy(self._key)
boundaries = get_key_boundaries(keys)
else:
boundaries = [block_accessor.num_rows()]
start = 0
for end in boundaries:
group_block = block_accessor.slice(start, end)
boundaries = [0, block_accessor.num_rows()]

for start, end in zip(boundaries[:-1], boundaries[1:]):
group_block = block_accessor.slice(start, end, copy=False)
group_block_accessor = BlockAccessor.for_block(group_block)
# Convert block of each group to batch format here, because the
# block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
group_batch = group_block_accessor.to_batch_format(batch_format)
applied = udf(group_batch, *args, **kwargs)
yield applied
start = end

if isinstance(fn, CallableClass):

Expand Down
27 changes: 27 additions & 0 deletions python/ray/data/tests/test_group_boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np

from ray.data._internal.boundaries import get_key_boundaries


def test_groupby_map_groups_get_key_boundaries():
    """Check group boundaries for multi-column and single-column keys."""
    # Two integer key columns: a boundary appears wherever either one changes.
    indices = get_key_boundaries(
        keys={
            "x": np.array([1, 1, 2, 2, 3, 3]),
            "y": np.array([1, 1, 2, 2, 3, 4]),
        }
    )

    assert list(indices) == [0, 2, 4, 5, 6]

    # Key columns of different dtypes (integers and strings).
    indices = get_key_boundaries(
        keys={
            "x": np.array([1, 1, 2, 2, 3, 3]),
            "y": np.array(["a", "b", "a", "a", "b", "b"]),
        }
    )

    assert list(indices) == [0, 1, 2, 4, 6]

    # A single key column passed as a plain array.
    indices = get_key_boundaries(np.array([1, 1, 2, 2, 3, 3]))

    assert list(indices) == [0, 2, 4, 6]

    # Without the leading 0 only the end index of each group remains.
    indices = get_key_boundaries(np.array([1, 1, 2, 2, 3, 3]), include_first=False)

    assert list(indices) == [2, 4, 6]

    # A single row forms exactly one group.
    indices = get_key_boundaries(np.array([5]))

    assert list(indices) == [0, 1]