Skip to content

Commit

Permalink
[Data] Optimizing the multi column groupby
Browse files Browse the repository at this point in the history
Replaced the custom key-class implementation of multi-column group boundaries with a pure NumPy implementation.

Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
  • Loading branch information
wingkitlee0 committed Jun 20, 2024
1 parent 231a013 commit 8e7f055
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 54 deletions.
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,11 @@ py_test(
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

# Runs tests/test_group_boundaries.py, which exercises get_key_boundaries
# from ray.data._internal.boundaries.
py_test(
name = "test_group_boundaries",
size = "small",
srcs = ["tests/test_group_boundaries.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)
29 changes: 29 additions & 0 deletions python/ray/data/_internal/boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Dict, Union

import numpy as np


def get_key_boundaries(
    keys: Union[np.ndarray, Dict[str, np.ndarray]], include_first: bool = True
) -> np.ndarray:
    """Compute group boundaries from one or more key columns.

    A new group starts at every row whose key differs from the key of the
    previous row. Only adjacent rows are compared, so rows with equal keys
    must already be contiguous (e.g. the data is sorted by the keys).

    Args:
        keys: A numpy array holding a single key column, or a dict mapping
            column names to equally sized numpy arrays for multiple keys.
        include_first: Whether to include the first index (0).

    Returns:
        An integer array with the starting index of each group followed by
        the end index of the last group (the number of rows). The leading
        ``0`` is present only when ``include_first`` is True.

    Raises:
        ValueError: If the key columns do not all have the same length.
    """
    if isinstance(keys, dict):
        columns = [np.asarray(column) for column in keys.values()]
    else:
        columns = [np.asarray(keys)]

    num_rows = len(columns[0]) if columns else 0
    if any(len(column) != num_rows for column in columns):
        raise ValueError("All key columns must have the same length.")

    # A row starts a new group when *any* key column changes relative to the
    # previous row. OR-ing the per-column masks avoids building one Python
    # tuple per row, which a record-array conversion would require.
    changed = np.zeros(max(num_rows - 1, 0), dtype=bool)
    for column in columns:
        changed |= np.asarray(column[1:] != column[:-1], dtype=bool)

    # ``changed[i]`` describes the transition from row i to row i + 1, hence
    # the + 1 to obtain the index at which the new group starts.
    starts = np.flatnonzero(changed) + 1

    # Build the end pieces with the same integer dtype as ``starts`` so that
    # concatenation never silently upcasts the result to float.
    pieces = [starts, np.array([num_rows], dtype=starts.dtype)]
    if include_first:
        pieces.insert(0, np.array([0], dtype=starts.dtype))
    return np.concatenate(pieces)
4 changes: 2 additions & 2 deletions python/ray/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def slice(self, start: int, end: int, copy: bool) -> Block:
"""Return a slice of this block.
Args:
start: The starting index of the slice.
end: The ending index of the slice.
start: The starting index of the slice (inclusive).
end: The ending index of the slice (exclusive).
copy: Whether to perform a data copy for the slice.
Returns:
Expand Down
65 changes: 13 additions & 52 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from ray.data._internal.boundaries import get_key_boundaries
from ray.data._internal.compute import ComputeStrategy
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.all_to_all_operator import Aggregate
from ray.data.aggregate import AggregateFn, Count, Max, Mean, Min, Std, Sum
from ray.data.block import BlockAccessor, CallableClass, UserDefinedFunction
from ray.data.dataset import DataBatch, Dataset
from ray.data.block import BlockAccessor, CallableClass, DataBatch, UserDefinedFunction
from ray.data.dataset import Dataset
from ray.util.annotations import PublicAPI


class _MultiColumnSortedKey:
    """Tuple of group-key values that can be ordered with ``<``.

    Used for multi-column groupby: each instance wraps the key values of one
    row. Plain tuples already sort lexicographically, but wrapping them in a
    single object with ``__lt__`` lets a 1D array of keys be searched with
    ``np.searchsorted`` when computing group key boundaries.
    """

    __slots__ = ("data",)

    def __init__(self, *args):
        # Store the per-column key values as an immutable tuple.
        self.data = tuple(args)

    def __lt__(self, obj: "_MultiColumnSortedKey") -> bool:
        # Lexicographic comparison, delegated to tuple ordering.
        mine, theirs = self.data, obj.data
        return mine < theirs

    def __repr__(self) -> str:
        """Render as ``T(1, 2)``."""
        return f"T{self.data!r}"


@PublicAPI
class GroupedData:
"""Represents a grouped dataset created by calling ``Dataset.groupby()``.
Expand All @@ -41,7 +20,7 @@ class GroupedData:
def __init__(
self,
dataset: Dataset,
key: Union[str, List[str]],
key: Optional[Union[str, List[str]]],
):
"""Construct a dataset grouped by key (internal API).
Expand Down Expand Up @@ -192,47 +171,29 @@ def map_groups(
else:
sorted_ds = self._dataset.repartition(1)

def get_key_boundaries(block_accessor: BlockAccessor) -> List[int]:
"""Compute block boundaries based on the key(s)"""

import numpy as np

# Get the keys of the batch in numpy array format
keys = block_accessor.to_numpy(self._key)

if isinstance(keys, dict):
# For multiple keys, we generate a separate tuple column
convert_to_multi_column_sorted_key = np.vectorize(_MultiColumnSortedKey)
keys: np.ndarray = convert_to_multi_column_sorted_key(*keys.values())

boundaries = []
start = 0
while start < keys.size:
end = start + np.searchsorted(keys[start:], keys[start], side="right")
boundaries.append(end)
start = end
return boundaries

# The batch is the entire block, because we have batch_size=None for
# map_batches() below.
def apply_udf_to_groups(udf, batch, *args, **kwargs):
block = BlockAccessor.batch_to_block(batch)
block_accessor = BlockAccessor.for_block(block)

# Get the list of boundaries including first start and last end indices
if self._key:
boundaries = get_key_boundaries(block_accessor)
# Get the keys of the batch in numpy array format
keys = block_accessor.to_numpy(self._key)
boundaries = get_key_boundaries(keys)
else:
boundaries = [block_accessor.num_rows()]
start = 0
for end in boundaries:
group_block = block_accessor.slice(start, end)
boundaries = [0, block_accessor.num_rows()]

for start, end in zip(boundaries[:-1], boundaries[1:]):
group_block = block_accessor.slice(start, end, copy=False)
group_block_accessor = BlockAccessor.for_block(group_block)
# Convert block of each group to batch format here, because the
# block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
group_batch = group_block_accessor.to_batch_format(batch_format)
applied = udf(group_batch, *args, **kwargs)
yield applied
start = end

if isinstance(fn, CallableClass):

Expand Down
27 changes: 27 additions & 0 deletions python/ray/data/tests/test_group_boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np

from ray.data._internal.boundaries import get_key_boundaries


def test_groupby_map_groups_get_key_boundaries():
    """Check the boundaries returned by ``get_key_boundaries``."""
    # Two integer key columns: a boundary appears wherever either column
    # changes, so the last two rows (y = 3 vs. 4) form separate groups.
    indices = get_key_boundaries(
        keys={
            "x": np.array([1, 1, 2, 2, 3, 3]),
            "y": np.array([1, 1, 2, 2, 3, 4]),
        }
    )

    assert list(indices) == [0, 2, 4, 5, 6]

    # Mixed integer / string key columns.
    indices = get_key_boundaries(
        keys={
            "x": np.array([1, 1, 2, 2, 3, 3]),
            "y": np.array(["a", "b", "a", "a", "b", "b"]),
        }
    )

    assert list(indices) == [0, 1, 2, 4, 6]

    # Single key column passed as a plain array.
    indices = get_key_boundaries(np.array([1, 1, 2, 2, 3, 3]))

    assert list(indices) == [0, 2, 4, 6]

    # include_first=False drops the leading 0 but keeps the final end index.
    indices = get_key_boundaries(np.array([1, 1, 2, 2, 3, 3]), include_first=False)

    assert list(indices) == [2, 4, 6]

    # A single row forms exactly one group.
    indices = get_key_boundaries(np.array([7]))

    assert list(indices) == [0, 1]

0 comments on commit 8e7f055

Please sign in to comment.