From bece986417aecf9805f34b6b54f4fb0813cf67e3 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 14:33:43 -0700 Subject: [PATCH 01/15] Fixed `cached_remote_fn` to avoid mixing up Ray's remote method instantiation for callers providing different arguments sets Signed-off-by: Alexey Kudinkin --- python/ray/data/_internal/remote_fn.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/remote_fn.py b/python/ray/data/_internal/remote_fn.py index fe0fd5ee83709..f5e47d05def71 100644 --- a/python/ray/data/_internal/remote_fn.py +++ b/python/ray/data/_internal/remote_fn.py @@ -16,7 +16,10 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: and should be set with ``options`` instead: ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``. """ - if fn not in CACHED_FUNCTIONS: + + args_hash = hash(ray_remote_args) + + if (fn, args_hash) not in CACHED_FUNCTIONS: default_ray_remote_args = { # Use the default scheduling strategy for all tasks so that we will # not inherit a placement group from the caller, if there is one. @@ -27,8 +30,12 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: } ray_remote_args = {**default_ray_remote_args, **ray_remote_args} _add_system_error_to_retry_exceptions(ray_remote_args) - CACHED_FUNCTIONS[fn] = ray.remote(**ray_remote_args)(fn) - return CACHED_FUNCTIONS[fn] + + # NOTE: Hash of the passed in arguments guarantees that we're caching + # complete instantiation of the Ray's remote method + CACHED_FUNCTIONS[(fn, args_hash)] = ray.remote(**ray_remote_args)(fn) + + return CACHED_FUNCTIONS[(fn, args_hash)] def _add_system_error_to_retry_exceptions(ray_remote_args) -> None: From 753be4095df0ea9801729df229e206e986c31e92 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 14:41:17 -0700 Subject: [PATCH 02/15] Make arg-list hashable Signed-off-by: Alexey Kudinkin --- python/ray/data/_internal/remote_fn.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/remote_fn.py b/python/ray/data/_internal/remote_fn.py index f5e47d05def71..e05616393eca2 100644 --- a/python/ray/data/_internal/remote_fn.py +++ b/python/ray/data/_internal/remote_fn.py @@ -12,12 +12,20 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: (ray imports ray.data in order to allow ``ray.data.read_foo()`` to work, which means ray.remote cannot be used top-level in ray.data). - Note: Dynamic arguments should not be passed in directly, + NOTE: Dynamic arguments should not be passed in directly, and should be set with ``options`` instead: ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``. """ - args_hash = hash(ray_remote_args) + # NOTE: Hash of the passed in arguments guarantees that we're caching + # complete instantiation of the Ray's remote method + # + # To compute the hash of passed in arguments and make sure it's deterministic + # - Sort all KV-pairs by the keys + # - Convert sorted list into tuple + # - Compute hash of the resulting tuple + arg_pairs = tuple(sorted(list(ray_remote_args.items()), key=lambda t: t[0])) + args_hash = hash(arg_pairs) if (fn, args_hash) not in CACHED_FUNCTIONS: default_ray_remote_args = { @@ -31,8 +39,7 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: ray_remote_args = {**default_ray_remote_args, **ray_remote_args} _add_system_error_to_retry_exceptions(ray_remote_args) - # NOTE: Hash of the passed in arguments guarantees that we're caching - # complete instantiation of the Ray's remote method + CACHED_FUNCTIONS[(fn, args_hash)] = ray.remote(**ray_remote_args)(fn) return CACHED_FUNCTIONS[(fn, args_hash)] From b4b30f0ed91c9d3989f917cc2263dbaa58ed057e Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 14:41:29 -0700 Subject: [PATCH 03/15] Added test Signed-off-by: Alexey Kudinkin --- python/ray/data/tests/test_util.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index d90f7ed395db9..cbc5ed9ac4d3c 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -11,6 +11,7 @@ trace_allocation, trace_deallocation, ) +from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import ( _check_pyarrow_version, _split_list, @@ -19,6 +20,20 @@ from ray.data.tests.conftest import * # noqa: F401, F403 +def test_cached_remote_fn(): + def foo(): + pass + + cpu_only_foo = cached_remote_fn(foo, num_cpus=1) + cached_cpu_only_foo = cached_remote_fn(foo, num_cpus=1) + + assert cpu_only_foo == cached_cpu_only_foo + + gpu_only_foo = cached_remote_fn(foo, num_gpus=1) + + assert cpu_only_foo != gpu_only_foo + + def test_check_pyarrow_version_bounds(unsupported_pyarrow_version): # Test that pyarrow versions outside of the defined bounds cause an ImportError to # be raised. From 6c67e74a9bdfea1fdbbe22dc0ec4d649d68a83b6 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 14:46:37 -0700 Subject: [PATCH 04/15] Pass in static Ray remote args for mapping tasks Signed-off-by: Alexey Kudinkin --- .../execution/operators/task_pool_map_operator.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index d4febdf8cee04..f01fd575bea3f 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -63,9 +63,17 @@ def __init__( ) self._concurrency = concurrency + # NOTE: Unlike static Ray remote args, dynamic arguments extracted from the blocks + # themselves are going to be passed inside `fn.options(...)` invocation + ray_remote_static_args = { + **ray_remote_args, + "num_returns": "streaming", + } + + self._map_task = cached_remote_fn(_map_task, **ray_remote_static_args) + def _add_bundled_input(self, bundle: RefBundle): # Submit the task as a normal Ray task. - map_task = cached_remote_fn(_map_task, num_returns="streaming") ctx = TaskContext( task_idx=self._next_data_task_idx, target_max_block_size=self.actual_target_max_block_size, @@ -82,7 +90,7 @@ def _add_bundled_input(self, bundle: RefBundle): 2 * data_context._max_num_blocks_in_streaming_gen_buffer ) - gen = map_task.options(**ray_remote_args).remote( + gen = self._map_task.options(**ray_remote_args).remote( self._map_transformer_ref, data_context, ctx, From 22955fc6115fb8a29508fd9c53020bed0cf2ebe9 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 14:48:59 -0700 Subject: [PATCH 05/15] `lint` Signed-off-by: Alexey Kudinkin --- .../_internal/execution/operators/task_pool_map_operator.py | 5 +++-- python/ray/data/_internal/remote_fn.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index f01fd575bea3f..454b31cc3e65b 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -63,8 +63,9 @@ def __init__( ) self._concurrency = concurrency - # NOTE: Unlike static Ray remote args, dynamic arguments extracted from the blocks - # themselves are going to be passed inside `fn.options(...)` invocation + # NOTE: Unlike static Ray remote args, dynamic arguments extracted from the + # blocks themselves are going to be passed inside `fn.options(...)` + # invocation ray_remote_static_args = { **ray_remote_args, "num_returns": "streaming", diff --git a/python/ray/data/_internal/remote_fn.py b/python/ray/data/_internal/remote_fn.py index e05616393eca2..aa44f1e564839 100644 --- a/python/ray/data/_internal/remote_fn.py +++ b/python/ray/data/_internal/remote_fn.py @@ -24,7 +24,7 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: # - Sort all KV-pairs by the keys # - Convert sorted list into tuple # - Compute hash of the resulting tuple - arg_pairs = tuple(sorted(list(ray_remote_args.items()), key=lambda t: t[0])) + arg_pairs = tuple(sorted(ray_remote_args.items(), key=lambda t: t[0])) args_hash = hash(arg_pairs) if (fn, args_hash) not in CACHED_FUNCTIONS: @@ -39,7 +39,6 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: ray_remote_args = {**default_ray_remote_args, **ray_remote_args} _add_system_error_to_retry_exceptions(ray_remote_args) - CACHED_FUNCTIONS[(fn, args_hash)] = ray.remote(**ray_remote_args)(fn) return CACHED_FUNCTIONS[(fn, args_hash)] From e2b53bc7c4917f50bf7d37c00d97b19fe7aaf547 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 15:06:46 -0700 Subject: [PATCH 06/15] Tidying up Signed-off-by: Alexey Kudinkin --- .../execution/operators/task_pool_map_operator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 454b31cc3e65b..07e316a0d2402 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -79,19 +79,20 @@ def _add_bundled_input(self, bundle: RefBundle): task_idx=self._next_data_task_idx, target_max_block_size=self.actual_target_max_block_size, ) - data_context = DataContext.get_current() - ray_remote_args = self._get_runtime_ray_remote_args(input_bundle=bundle) - ray_remote_args["name"] = self.name + dynamic_ray_remote_args = self._get_runtime_ray_remote_args(input_bundle=bundle) + dynamic_ray_remote_args["name"] = self.name + + data_context = DataContext.get_current() if data_context._max_num_blocks_in_streaming_gen_buffer is not None: # The `_generator_backpressure_num_objects` parameter should be # `2 * _max_num_blocks_in_streaming_gen_buffer` because we yield # 2 objects for each block: the block and the block metadata. - ray_remote_args["_generator_backpressure_num_objects"] = ( + dynamic_ray_remote_args["_generator_backpressure_num_objects"] = ( 2 * data_context._max_num_blocks_in_streaming_gen_buffer ) - gen = self._map_task.options(**ray_remote_args).remote( + gen = self._map_task.options(**dynamic_ray_remote_args).remote( self._map_transformer_ref, data_context, ctx, From 53a5d835fb42aafd213bda61989c02af5e552588 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Fri, 20 Sep 2024 15:20:50 -0700 Subject: [PATCH 07/15] Added E2E test Signed-off-by: Alexey Kudinkin --- python/ray/data/tests/test_map.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 9e01ab3b83c41..8c67b6570be07 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -256,6 +256,26 @@ def __call__(self, x): ds.map_batches(Mapper, concurrency=1).materialize() +def test_gpu_workers_not_reused(shutdown_only): + """By default, in Ray Core if `num_gpus` is specified workers will not be reused + for tasks invocation. + + For more context check out https://github.com/ray-project/ray/issues/29624""" + + ray.init(num_gpus=1) + + total_blocks = 5 + ds = ray.data.range(5, override_num_blocks=total_blocks) + + def _get_worker_id(_): + worker = ray._private.worker.global_worker + return {"worker_id": worker.core_worker.get_worker_id().hex()} + + unique_worker_ids = ds.map(_get_worker_id, num_gpus=1).unique("worker_id") + + assert len(unique_worker_ids) == total_blocks + + def test_concurrency(shutdown_only): ray.init(num_cpus=6) ds = ray.data.range(10, override_num_blocks=10) From 30100ff00af4c8e191cabaeebe4476c253fbf96f Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Mon, 23 Sep 2024 11:57:04 -0700 Subject: [PATCH 08/15] Fixing NPE Signed-off-by: Alexey Kudinkin --- .../_internal/execution/operators/task_pool_map_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 07e316a0d2402..acce7915a8c52 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -67,7 +67,7 @@ def __init__( # blocks themselves are going to be passed inside `fn.options(...)` # invocation ray_remote_static_args = { - **ray_remote_args, + **(ray_remote_args or {}), "num_returns": "streaming", } From 18b4447ef3402eef076610a36f73892f466bd48e Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Mon, 23 Sep 2024 12:02:10 -0700 Subject: [PATCH 09/15] Tidying up Signed-off-by: Alexey Kudinkin --- python/ray/data/tests/test_map.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 8c67b6570be07..675a0a3dd417e 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -268,8 +268,7 @@ def test_gpu_workers_not_reused(shutdown_only): ds = ray.data.range(5, override_num_blocks=total_blocks) def _get_worker_id(_): - worker = ray._private.worker.global_worker - return {"worker_id": worker.core_worker.get_worker_id().hex()} + return {"worker_id": ray.get_runtime_context().get_worker_id()} unique_worker_ids = ds.map(_get_worker_id, num_gpus=1).unique("worker_id") From 94621b5d55b3ccb02be5e6b6bfa7d15c8d75a992 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Tue, 24 Sep 2024 11:31:39 -0700 Subject: [PATCH 10/15] Generalized conversion to `make_hashable` utility Signed-off-by: Alexey Kudinkin --- python/ray/data/_internal/remote_fn.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/remote_fn.py b/python/ray/data/_internal/remote_fn.py index aa44f1e564839..ae80855b6945a 100644 --- a/python/ray/data/_internal/remote_fn.py +++ b/python/ray/data/_internal/remote_fn.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict, List import ray @@ -24,8 +24,8 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: # - Sort all KV-pairs by the keys # - Convert sorted list into tuple # - Compute hash of the resulting tuple - arg_pairs = tuple(sorted(ray_remote_args.items(), key=lambda t: t[0])) - args_hash = hash(arg_pairs) + hashable_args = _make_hashable(ray_remote_args) + args_hash = hash(hashable_args) if (fn, args_hash) not in CACHED_FUNCTIONS: default_ray_remote_args = { @@ -44,6 +44,18 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any: return CACHED_FUNCTIONS[(fn, args_hash)] +def _make_hashable(obj): + if isinstance(obj, (List, tuple)): + return tuple([_make_hashable(o) for o in obj]) + elif isinstance(obj, Dict): + converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()] + return tuple(sorted(converted, key=lambda t: t[0])) + elif isinstance(obj, (bool, int, float, str, bytes, type(None))): + return obj + else: + raise ValueError(f"Type {type(obj)} is not hashable") + + def _add_system_error_to_retry_exceptions(ray_remote_args) -> None: """Modify the remote args so that Ray retries `RaySystemError`s. From 5d42ebb28d3da8e937da5906897faf2ccde28e1d Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Tue, 24 Sep 2024 11:31:44 -0700 Subject: [PATCH 11/15] Added tests Signed-off-by: Alexey Kudinkin --- python/ray/data/tests/test_util.py | 51 +++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index cbc5ed9ac4d3c..91662e9fd878b 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -11,7 +11,7 @@ trace_allocation, trace_deallocation, ) -from ray.data._internal.remote_fn import cached_remote_fn +from ray.data._internal.remote_fn import cached_remote_fn, _make_hashable from ray.data._internal.util import ( _check_pyarrow_version, _split_list, @@ -34,6 +34,55 @@ def foo(): assert cpu_only_foo != gpu_only_foo +def test_make_hashable(): + valid_args = { + "int": 0, + "float": 1.2, + "str": "foo", + "dict": { + 0: 0, + 1.2: 1.2, + }, + "list": list(range(10)), + "tuple": tuple(range(3)), + } + + hashable_args = _make_hashable(valid_args) + + assert hash(hashable_args) == hash(( + ('dict', ((0, 0), (1.2, 1.2))), + ('float', 1.2), + ('int', 0), + ('list', (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), + ('str', 'foo'), + ('tuple', (0, 1, 2)), + )) + + # Invalid case # 1: can't mix up key types + invalid_args = { + 0: 1, + "bar": "baz" + } + + with pytest.raises(TypeError) as exc_info: + _make_hashable(invalid_args) + + assert str(exc_info.value) == "'<' not supported between instances of 'str' and 'int'" + + # Invalid case # 2: can't use anything but dict, list, tuple or primitive types + class Foo: + bar: 0 + + invalid_args = { + 0: Foo(), + } + + with pytest.raises(ValueError) as exc_info: + _make_hashable(invalid_args) + + assert str(exc_info.value) == "Type .Foo'> is not hashable" + + def test_check_pyarrow_version_bounds(unsupported_pyarrow_version): # Test that pyarrow versions outside of the defined bounds cause an ImportError to # be raised. From 17bdc08ab268f8ded4db43a2512cae8aad264100 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Tue, 24 Sep 2024 11:32:09 -0700 Subject: [PATCH 12/15] `lint` Signed-off-by: Alexey Kudinkin --- python/ray/data/tests/test_util.py | 34 +++++++++++++++++------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index 91662e9fd878b..59229b5f94f3c 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -11,7 +11,7 @@ trace_allocation, trace_deallocation, ) -from ray.data._internal.remote_fn import cached_remote_fn, _make_hashable +from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn from ray.data._internal.util import ( _check_pyarrow_version, _split_list, @@ -49,25 +49,26 @@ def test_make_hashable(): hashable_args = _make_hashable(valid_args) - assert hash(hashable_args) == hash(( - ('dict', ((0, 0), (1.2, 1.2))), - ('float', 1.2), - ('int', 0), - ('list', (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), - ('str', 'foo'), - ('tuple', (0, 1, 2)), - )) + assert hash(hashable_args) == hash( + ( + ("dict", ((0, 0), (1.2, 1.2))), + ("float", 1.2), + ("int", 0), + ("list", (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), + ("str", "foo"), + ("tuple", (0, 1, 2)), + ) + ) # Invalid case # 1: can't mix up key types - invalid_args = { - 0: 1, - "bar": "baz" - } + invalid_args = {0: 1, "bar": "baz"} with pytest.raises(TypeError) as exc_info: _make_hashable(invalid_args) - assert str(exc_info.value) == "'<' not supported between instances of 'str' and 'int'" + assert ( + str(exc_info.value) == "'<' not supported between instances of 'str' and 'int'" + ) # Invalid case # 2: can't use anything but dict, list, tuple or primitive types class Foo: @@ -80,7 +81,10 @@ class Foo: with pytest.raises(ValueError) as exc_info: _make_hashable(invalid_args) - assert str(exc_info.value) == "Type .Foo'> is not hashable" + assert ( + str(exc_info.value) + == "Type .Foo'> is not hashable" + ) def test_check_pyarrow_version_bounds(unsupported_pyarrow_version): From 277e0a03201985b9067e90177fdb3a66127a021f Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Tue, 24 Sep 2024 14:17:08 -0700 Subject: [PATCH 13/15] Add support for more generic `Hashable` args Signed-off-by: Alexey Kudinkin --- python/ray/data/_internal/remote_fn.py | 4 ++-- python/ray/data/tests/test_util.py | 19 +++---------------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/python/ray/data/_internal/remote_fn.py b/python/ray/data/_internal/remote_fn.py index ae80855b6945a..511604c0bd2e6 100644 --- a/python/ray/data/_internal/remote_fn.py +++ b/python/ray/data/_internal/remote_fn.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, Hashable, List import ray @@ -50,7 +50,7 @@ def _make_hashable(obj): elif isinstance(obj, Dict): converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()] return tuple(sorted(converted, key=lambda t: t[0])) - elif isinstance(obj, (bool, int, float, str, bytes, type(None))): + elif isinstance(obj, Hashable): return obj else: raise ValueError(f"Type {type(obj)} is not hashable") diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index 59229b5f94f3c..b66a9bc5804f8 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from typing_extensions import Hashable import ray from ray.data._internal.datasource.parquet_datasource import ParquetDatasource @@ -45,6 +46,7 @@ def test_make_hashable(): }, "list": list(range(10)), "tuple": tuple(range(3)), + "type": Hashable, } hashable_args = _make_hashable(valid_args) @@ -57,6 +59,7 @@ def test_make_hashable(): ("list", (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), ("str", "foo"), ("tuple", (0, 1, 2)), + ("type", Hashable), ) ) @@ -70,22 +73,6 @@ def test_make_hashable(): str(exc_info.value) == "'<' not supported between instances of 'str' and 'int'" ) - # Invalid case # 2: can't use anything but dict, list, tuple or primitive types - class Foo: - bar: 0 - - invalid_args = { - 0: Foo(), - } - - with pytest.raises(ValueError) as exc_info: - _make_hashable(invalid_args) - - assert ( - str(exc_info.value) - == "Type .Foo'> is not hashable" - ) - def test_check_pyarrow_version_bounds(unsupported_pyarrow_version): # Test that pyarrow versions outside of the defined bounds cause an ImportError to From 4711bfe72081b44a56feb882a93a8ba2611c3ab5 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Thu, 26 Sep 2024 21:06:23 -0700 Subject: [PATCH 14/15] Use appropriately initialized base field Signed-off-by: Alexey Kudinkin --- .../_internal/execution/operators/task_pool_map_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index acce7915a8c52..b565c38e56628 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -67,7 +67,7 @@ def __init__( # blocks themselves are going to be passed inside `fn.options(...)` # invocation ray_remote_static_args = { - **(ray_remote_args or {}), + **(self._ray_remote_args or {}), "num_returns": "streaming", } From 65a1b2875e1b49b0e4dbf77c7052ef01687901fa Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Thu, 26 Sep 2024 21:06:55 -0700 Subject: [PATCH 15/15] Fixed operator fusion to properly handle `scheduling_strategy`; Fixed tests Signed-off-by: Alexey Kudinkin --- python/ray/data/_internal/logical/rules/operator_fusion.py | 5 ++++- python/ray/data/tests/test_execution_optimizer.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 79fd67ed64de0..30fdb2dc4fa11 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -437,8 +437,11 @@ def _are_remote_args_compatible(prev_args, next_args): next_args = _canonicalize(next_args) remote_args = next_args.copy() for key in INHERITABLE_REMOTE_ARGS: - if key in prev_args: + # NOTE: We only carry over inheritable value in case + # of it not being provided in the remote args + if key in prev_args and key not in remote_args: remote_args[key] = prev_args[key] + if prev_args != remote_args: return False return True diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 1712fa1a09620..913c7ee1822bf 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -661,7 +661,7 @@ def test_read_map_batches_operator_fusion_incompatible_remote_args( ray_start_regular_shared, ): # Test that map operators won't get fused if the remote args are incompatible. - incompatiple_remote_args_pairs = [ + incompatible_remote_args_pairs = [ # Use different resources. ({"num_cpus": 2}, {"num_gpus": 2}), # Same resource, but different values. @@ -670,9 +670,9 @@ def test_read_map_batches_operator_fusion_incompatible_remote_args( ({"resources": {"custom": 2}}, {"resources": {"custom": 1}}), ({"resources": {"custom1": 1}}, {"resources": {"custom2": 1}}), # Different scheduling strategies. - ({"scheduling_strategy": "SPREAD"}, {"scheduing_strategy": "PACK"}), + ({"scheduling_strategy": "SPREAD"}, {"scheduling_strategy": "PACK"}), ] - for up_remote_args, down_remote_args in incompatiple_remote_args_pairs: + for up_remote_args, down_remote_args in incompatible_remote_args_pairs: planner = Planner() read_op = get_parquet_read_logical_op( ray_remote_args={"resources": {"non-existent": 1}}