Skip to content

Commit

Permalink
[Data] Optimizing the multi column groupby
Browse files Browse the repository at this point in the history
Changed from a custom class implementation to a purely numpy implementation

Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
  • Loading branch information
wingkitlee0 committed Sep 29, 2024
1 parent e07594e commit b6076d6
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 52 deletions.
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,11 @@ py_test(
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

# Small test target that runs tests/test_group_boundaries.py.
py_test(
name = "test_group_boundaries",
size = "small",
srcs = ["tests/test_group_boundaries.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)
29 changes: 29 additions & 0 deletions python/ray/data/_internal/boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Dict, Union

import numpy as np


def get_key_boundaries(
    keys: Union[np.ndarray, Dict[str, np.ndarray]], include_first: bool = True
) -> np.ndarray:
    """Compute the boundaries of runs of consecutive equal keys.

    A boundary is placed at every row whose key differs from the key of the
    previous row, so callers that want one run per distinct key are expected
    to pass keys that are already sorted (or otherwise grouped).

    Missing values (NaN in float/complex columns, NaT in datetime/timedelta
    columns) are treated as equal to each other, so adjacent missing keys
    form a single group instead of one group per row.

    Args:
        keys: A 1-D numpy array of group keys, or, for multi-column keys, a
            dict mapping column name to a 1-D numpy array. All columns must
            have the same length.
        include_first: Whether to include the first index (0).

    Returns:
        An integer array holding the starting index of each group followed
        by ``len(keys)`` (the exclusive end of the last group). If
        ``include_first`` is False the leading 0 is omitted, i.e. the array
        holds the exclusive end index of each group. For empty input the
        result is ``[0, 0]`` (or ``[0]`` without the first index).

    Raises:
        ValueError: If the key columns do not all have the same length.
    """
    columns = list(keys.values()) if isinstance(keys, dict) else [keys]
    num_rows = len(columns[0]) if columns else 0

    # changed[i] is True when row i + 1 starts a new group. Comparing column
    # by column keeps the work vectorized; building row tuples in Python to
    # form a record array would loop over every row in the interpreter.
    changed = np.zeros(max(num_rows - 1, 0), dtype=bool)
    for column in columns:
        column = np.asarray(column)
        if len(column) != num_rows:
            raise ValueError(
                "All key columns must have the same length, got "
                f"{len(column)} and {num_rows}."
            )
        prev, curr = column[:-1], column[1:]
        differs = np.asarray(curr != prev, dtype=bool)
        # NaN != NaN and NaT != NaT evaluate to True; without this correction
        # every missing key would open its own group.
        if column.dtype.kind in "fc":
            differs &= ~(np.isnan(curr) & np.isnan(prev))
        elif column.dtype.kind in "mM":
            differs &= ~(np.isnat(curr) & np.isnat(prev))
        changed |= differs

    parts = [np.flatnonzero(changed) + 1, [num_rows]]
    if include_first:
        parts.insert(0, [0])
    return np.hstack(parts)
4 changes: 2 additions & 2 deletions python/ray/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def slice(self, start: int, end: int, copy: bool) -> Block:
"""Return a slice of this block.
Args:
start: The starting index of the slice.
end: The ending index of the slice.
start: The starting index of the slice (inclusive).
end: The ending index of the slice (exclusive).
copy: Whether to perform a data copy for the slice.
Returns:
Expand Down
61 changes: 11 additions & 50 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from ray.data._internal.aggregate import Count, Max, Mean, Min, Std, Sum
from ray.data._internal.boundaries import get_key_boundaries
from ray.data._internal.compute import ComputeStrategy
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.all_to_all_operator import Aggregate
Expand All @@ -13,28 +14,6 @@
FA_API_GROUP = "Function Application"


class _MultiColumnSortedKey:
    """Tuple of group-key values that can be ordered with ``<``.

    Used for multi-column groupby: each instance wraps the key values of one
    row. Ordering follows tuple comparison of the wrapped values, which is
    the comparison ``np.searchsorted`` relies on when locating group key
    boundaries in an array of these objects.
    """

    __slots__ = ("data",)

    def __init__(self, *args):
        # ``args`` is already a tuple of the per-column key values.
        self.data = args

    def __lt__(self, obj: "_MultiColumnSortedKey") -> bool:
        mine, theirs = self.data, obj.data
        return mine < theirs

    def __repr__(self) -> str:
        """Render as ``T(1, 2)``."""
        return f"T{self.data!r}"


class GroupedData:
"""Represents a grouped dataset created by calling ``Dataset.groupby()``.
Expand All @@ -44,7 +23,7 @@ class GroupedData:
def __init__(
self,
dataset: Dataset,
key: Union[str, List[str]],
key: Optional[Union[str, List[str]]],
):
"""Construct a dataset grouped by key (internal API).
Expand Down Expand Up @@ -197,47 +176,29 @@ def map_groups(
else:
sorted_ds = self._dataset.repartition(1)

def get_key_boundaries(block_accessor: BlockAccessor) -> List[int]:
    """Return the exclusive end index of each run of equal keys in the block."""

    import numpy as np

    # Key column(s) of the block in numpy format.
    key_values = block_accessor.to_numpy(self._key)

    if isinstance(key_values, dict):
        # Multiple key columns: wrap the values of each row in a single
        # comparable object so the keys form one array.
        to_sorted_key = np.vectorize(_MultiColumnSortedKey)
        key_values = to_sorted_key(*key_values.values())

    ends = []
    cursor = 0
    while cursor < key_values.size:
        # Number of remaining rows whose key is <= the key at ``cursor``.
        run_length = np.searchsorted(
            key_values[cursor:], key_values[cursor], side="right"
        )
        cursor = cursor + run_length
        ends.append(cursor)
    return ends

# The batch is the entire block, because we have batch_size=None for
# map_batches() below.
def apply_udf_to_groups(udf, batch, *args, **kwargs):
block = BlockAccessor.batch_to_block(batch)
block_accessor = BlockAccessor.for_block(block)

# Get the list of boundaries including first start and last end indices
if self._key:
boundaries = get_key_boundaries(block_accessor)
# Get the keys of the batch in numpy array format
keys = block_accessor.to_numpy(self._key)
boundaries = get_key_boundaries(keys)
else:
boundaries = [block_accessor.num_rows()]
start = 0
for end in boundaries:
group_block = block_accessor.slice(start, end)
boundaries = [0, block_accessor.num_rows()]

for start, end in zip(boundaries[:-1], boundaries[1:]):
group_block = block_accessor.slice(start, end, copy=False)
group_block_accessor = BlockAccessor.for_block(group_block)
# Convert block of each group to batch format here, because the
# block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
group_batch = group_block_accessor.to_batch_format(batch_format)
applied = udf(group_batch, *args, **kwargs)
yield applied
start = end

if isinstance(fn, CallableClass):

Expand Down
27 changes: 27 additions & 0 deletions python/ray/data/tests/test_group_boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np

from ray.data._internal.boundaries import get_key_boundaries


def test_groupby_map_groups_get_key_boundaries():
    """Check ``get_key_boundaries`` on multi-column and single-column keys."""
    x = np.array([1, 1, 2, 2, 3, 3])

    # Two integer key columns: a new group starts whenever either changes.
    indices = get_key_boundaries(
        keys={
            "x": x,
            "y": np.array([1, 1, 2, 2, 3, 4]),
        }
    )
    assert list(indices) == [0, 2, 4, 5, 6]

    # Mixed integer and string key columns.
    indices = get_key_boundaries(
        keys={
            "x": x,
            "y": np.array(["a", "b", "a", "a", "b", "b"]),
        }
    )
    assert list(indices) == [0, 1, 2, 4, 6]

    # Single key column passed as a plain array.
    indices = get_key_boundaries(x)
    assert list(indices) == [0, 2, 4, 6]

    # Without the leading 0 only the end index of each group is returned.
    indices = get_key_boundaries(x, include_first=False)
    assert list(indices) == [2, 4, 6]

    # A block holding a single group has no interior boundary.
    indices = get_key_boundaries(np.array([7, 7, 7]))
    assert list(indices) == [0, 3]

0 comments on commit b6076d6

Please sign in to comment.