diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py index bcb4b704aefb9..edeea0639464a 100644 --- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py @@ -174,34 +174,30 @@ def sample_boundaries( # TODO(zhilong): Update sort sample bar before finished. samples = sample_bar.fetch_until_complete(sample_results) del sample_results - samples = [s for s in samples if len(s) > 0] + samples: List[Block] = [s for s in samples if len(s) > 0] # The dataset is empty if len(samples) == 0: return [None] * (num_reducers - 1) + + # Convert samples to a sorted list[tuple[...]] where each tuple represents a + # sample. + # TODO: Once we deprecate pandas blocks, we can avoid this conversion and + # directly sort the samples. builder = DelegatingBlockBuilder() for sample in samples: builder.add_block(sample) - samples = builder.build() - - sample_dict = BlockAccessor.for_block(samples).to_numpy(columns=columns) - # Compute sorted indices of the samples. In np.lexsort last key is the - # primary key hence have to reverse the order. - indices = np.lexsort(list(reversed(list(sample_dict.values())))) - # Sort each column by indices, and calculate q-ths quantile items. - # Ignore the 1st item as it's not required for the boundary - for k, v in sample_dict.items(): - sorted_v = v[indices] - sample_dict[k] = list( - np.quantile( - sorted_v, np.linspace(0, 1, num_reducers), interpolation="nearest" - )[1:] - ) - # Return the list of boundaries as tuples - # of a form (col1_value, col2_value, ...) - return [ - tuple(sample_dict[k][i] for k in sample_dict) - for i in range(num_reducers - 1) + samples_table = builder.build() + samples_dict = BlockAccessor.for_block(samples_table).to_numpy(columns=columns) + # This zip does the transposition from list of column values to list of tuples. + samples_list = sorted(zip(*samples_dict.values())) + + # Each boundary corresponds to a quantile of the data. + quantile_indices = [ + int(q * (len(samples_list) - 1)) + for q in np.linspace(0, 1, num_reducers + 1) ] + # Exclude the first and last quantiles because they're 0 and 1. + return [samples_list[i] for i in quantile_indices[1:-1]] def _sample_block(block: Block, n_samples: int, sort_key: SortKey) -> Block: diff --git a/python/ray/data/tests/test_sort.py b/python/ray/data/tests/test_sort.py index 2246723553ab0..982cf99409d4e 100644 --- a/python/ray/data/tests/test_sort.py +++ b/python/ray/data/tests/test_sort.py @@ -51,6 +51,26 @@ def test_sort_with_specified_boundaries(ray_start_regular, descending, boundarie assert np.all(block["id"] == expected_block) +def test_sort_multiple_keys_produces_equally_sized_blocks(ray_start_regular): + # Test for https://github.com/ray-project/ray/issues/45303. + ds = ray.data.from_items( + [{"a": i, "b": j} for i in range(2) for j in range(5)], override_num_blocks=5 + ) + + ds_sorted = ds.sort(["a", "b"]) + + num_rows_per_block = [ + bundle.num_rows() for bundle in ds_sorted.iter_internal_ref_bundles() + ] + # Number of output blocks should be equal to the number of input blocks. + assert len(num_rows_per_block) == 5, len(num_rows_per_block) + # Ideally we should have 10 rows / 5 blocks = 2 rows per block, but to make this + # test less fragile we allow for a small deviation. + assert all( + 1 <= num_rows <= 3 for num_rows in num_rows_per_block + ), num_rows_per_block + + def test_sort_simple(ray_start_regular, use_push_based_shuffle): num_items = 100 parallelism = 4