Skip to content

Commit

Permalink
Add support for more generic Hashable args
Browse files Browse the repository at this point in the history
Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
  • Loading branch information
alexeykudinkin committed Sep 24, 2024
1 parent 17bdc08 commit 277e0a0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 18 deletions.
4 changes: 2 additions & 2 deletions python/ray/data/_internal/remote_fn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, Hashable, List

import ray

Expand Down Expand Up @@ -50,7 +50,7 @@ def _make_hashable(obj):
elif isinstance(obj, Dict):
converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()]
return tuple(sorted(converted, key=lambda t: t[0]))
elif isinstance(obj, (bool, int, float, str, bytes, type(None))):
elif isinstance(obj, Hashable):
return obj
else:
raise ValueError(f"Type {type(obj)} is not hashable")
Expand Down
19 changes: 3 additions & 16 deletions python/ray/data/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
from typing_extensions import Hashable

import ray
from ray.data._internal.datasource.parquet_datasource import ParquetDatasource
Expand Down Expand Up @@ -45,6 +46,7 @@ def test_make_hashable():
},
"list": list(range(10)),
"tuple": tuple(range(3)),
"type": Hashable,
}

hashable_args = _make_hashable(valid_args)
Expand All @@ -57,6 +59,7 @@ def test_make_hashable():
("list", (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
("str", "foo"),
("tuple", (0, 1, 2)),
("type", Hashable),
)
)

Expand All @@ -70,22 +73,6 @@ def test_make_hashable():
str(exc_info.value) == "'<' not supported between instances of 'str' and 'int'"
)

# Invalid case # 2: can't use anything but dict, list, tuple or primitive types
class Foo:
bar: 0

invalid_args = {
0: Foo(),
}

with pytest.raises(ValueError) as exc_info:
_make_hashable(invalid_args)

assert (
str(exc_info.value)
== "Type <class 'test_util.test_make_hashable.<locals>.Foo'> is not hashable"
)


def test_check_pyarrow_version_bounds(unsupported_pyarrow_version):
# Test that pyarrow versions outside of the defined bounds cause an ImportError to
Expand Down

0 comments on commit 277e0a0

Please sign in to comment.