Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Fix unequal partitions when grouping by multiple keys #47924

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions python/ray/data/_internal/planner/exchange/sort_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,34 +174,30 @@ def sample_boundaries(
# TODO(zhilong): Update sort sample bar before finished.
samples = sample_bar.fetch_until_complete(sample_results)
del sample_results
samples = [s for s in samples if len(s) > 0]
samples: List[Block] = [s for s in samples if len(s) > 0]
# The dataset is empty
if len(samples) == 0:
return [None] * (num_reducers - 1)

# Convert samples to a sorted list[tuple[...]] where each tuple represents a
# sample.
# TODO: Once we deprecate pandas blocks, we can avoid this conversion and
# directly sort the samples.
builder = DelegatingBlockBuilder()
for sample in samples:
builder.add_block(sample)
samples = builder.build()

sample_dict = BlockAccessor.for_block(samples).to_numpy(columns=columns)
# Compute sorted indices of the samples. In np.lexsort last key is the
# primary key hence have to reverse the order.
indices = np.lexsort(list(reversed(list(sample_dict.values()))))
# Sort each column by indices, and calculate q-ths quantile items.
# Ignore the 1st item as it's not required for the boundary
for k, v in sample_dict.items():
sorted_v = v[indices]
sample_dict[k] = list(
np.quantile(
sorted_v, np.linspace(0, 1, num_reducers), interpolation="nearest"
)[1:]
)
# Return the list of boundaries as tuples
# of a form (col1_value, col2_value, ...)
return [
tuple(sample_dict[k][i] for k in sample_dict)
for i in range(num_reducers - 1)
samples_table = builder.build()
samples_dict = BlockAccessor.for_block(samples_table).to_numpy(columns=columns)
# This zip does the transposition from list of column values to list of tuples.
samples_list = sorted(zip(*samples_dict.values()))

# Each boundary corresponds to a quantile of the data.
quantile_indices = [
int(q * (len(samples_list) - 1))
for q in np.linspace(0, 1, num_reducers + 1)
]
# Exclude the first and last quantiles because they're 0 and 1.
return [samples_list[i] for i in quantile_indices[1:-1]]


def _sample_block(block: Block, n_samples: int, sort_key: SortKey) -> Block:
Expand Down
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ def test_sort_with_specified_boundaries(ray_start_regular, descending, boundarie
assert np.all(block["id"] == expected_block)


def test_sort_multiple_keys_produces_equally_sized_blocks(ray_start_regular):
    # Regression test for https://github.com/ray-project/ray/issues/45303:
    # sorting by multiple keys used to produce heavily skewed partitions.
    items = [{"a": a, "b": b} for a in range(2) for b in range(5)]
    ds = ray.data.from_items(items, override_num_blocks=5)

    sorted_ds = ds.sort(["a", "b"])

    block_sizes = []
    for bundle in sorted_ds.iter_internal_ref_bundles():
        block_sizes.append(bundle.num_rows())

    # The sort should preserve the input block count.
    assert len(block_sizes) == 5, len(block_sizes)
    # A perfectly even split is 10 rows / 5 blocks = 2 rows each; allow one
    # row of slack in either direction so the test isn't overly brittle.
    assert all(1 <= size <= 3 for size in block_sizes), block_sizes


def test_sort_simple(ray_start_regular, use_push_based_shuffle):
num_items = 100
parallelism = 4
Expand Down
Loading