Skip to content

Commit

Permalink
[Data] Raise ValueError when the data sort key is None (ray-project#48969)
Browse files Browse the repository at this point in the history

Closes: ray-project#48926
ray-project#48927
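<!-- Please give a short summary of the change and the problem this
solves. -->
Summary: `Dataset.sort` now requires an explicit `key` (the `None` default is
removed from the signature) and raises `ValueError` if `key` is `None`; the
affected test in `test_optimize.py` now sorts by `"id"`.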

Signed-off-by: Superskyyy <yihaochen@apache.org>
Signed-off-by: Richard Liaw <rliaw@berkeley.edu>
Signed-off-by: Hao Chen <chenh1024@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Hao Chen <chenh1024@gmail.com>
Signed-off-by: Connor Sanders <connor@elastiflow.com>
  • Loading branch information
3 people authored and jecsand838 committed Dec 4, 2024
1 parent cd6c2fa commit 1cf6452
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,11 +2435,12 @@ def std(
@PublicAPI(api_group=SSR_API_GROUP)
def sort(
self,
key: Union[str, List[str], None] = None,
key: Union[str, List[str]],
descending: Union[bool, List[bool]] = False,
boundaries: List[Union[int, float]] = None,
) -> "Dataset":
"""Sort the dataset by the specified key column or key function.
The `key` parameter must be specified (i.e., it cannot be `None`).
.. note::
The `descending` parameter must be a boolean, or a list of booleans.
Expand Down Expand Up @@ -2488,7 +2489,12 @@ def sort(
Returns:
A new, sorted :class:`Dataset`.
Raises:
``ValueError``: if the sort key is None.
"""
if key is None:
raise ValueError("The 'key' parameter cannot be None for sorting.")
sort_key = SortKey(key, descending, boundaries)
plan = self._plan.copy()
op = Sort(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_optimize_lazy_reuse_base_data(


def test_require_preserve_order(ray_start_regular_shared):
ds = ray.data.range(100).map_batches(lambda x: x).sort()
ds = ray.data.range(100).map_batches(lambda x: x).sort("id")
assert ds._plan.require_preserve_order()
ds2 = ray.data.range(100).map_batches(lambda x: x).zip(ds)
assert ds2._plan.require_preserve_order()
Expand Down

0 comments on commit 1cf6452

Please sign in to comment.