[Data] Make sure num_gpus provided to Ray Data is appropriately passed to ray.remote call #47768

Merged · 15 commits · Sep 27, 2024
@@ -63,26 +63,36 @@ def __init__(
)
self._concurrency = concurrency

# NOTE: Unlike static Ray remote args, dynamic arguments extracted from the
# blocks themselves are passed via the `fn.options(...)` invocation.
ray_remote_static_args = {
**(self._ray_remote_args or {}),
"num_returns": "streaming",
}

self._map_task = cached_remote_fn(_map_task, **ray_remote_static_args)

def _add_bundled_input(self, bundle: RefBundle):
# Submit the task as a normal Ray task.
map_task = cached_remote_fn(_map_task, num_returns="streaming")
ctx = TaskContext(
task_idx=self._next_data_task_idx,
target_max_block_size=self.actual_target_max_block_size,
)
data_context = DataContext.get_current()
ray_remote_args = self._get_runtime_ray_remote_args(input_bundle=bundle)
ray_remote_args["name"] = self.name

dynamic_ray_remote_args = self._get_runtime_ray_remote_args(input_bundle=bundle)
dynamic_ray_remote_args["name"] = self.name

data_context = DataContext.get_current()
if data_context._max_num_blocks_in_streaming_gen_buffer is not None:
# The `_generator_backpressure_num_objects` parameter should be
# `2 * _max_num_blocks_in_streaming_gen_buffer` because we yield
# 2 objects for each block: the block and the block metadata.
ray_remote_args["_generator_backpressure_num_objects"] = (
dynamic_ray_remote_args["_generator_backpressure_num_objects"] = (
2 * data_context._max_num_blocks_in_streaming_gen_buffer
)

gen = map_task.options(**ray_remote_args).remote(
gen = self._map_task.options(**dynamic_ray_remote_args).remote(
self._map_transformer_ref,
data_context,
ctx,
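The shape of this change, sketched below (the `_map_task` body and argument values are made up for illustration; `cached_remote_fn` is the real helper touched later in this PR, and the task only schedules if a GPU is available): user-provided remote args such as `num_gpus` are baked into the cached remote function as static args, while per-task values still flow through `.options(...)`.

```python
import ray
from ray.data._internal.remote_fn import cached_remote_fn

def _map_task(block):
    # A streaming task yields 2 objects per block: the block and its metadata.
    yield block
    yield {"num_rows": len(block)}

# Static args (user resources plus the streaming-generator mode) are fixed
# once at creation time and now participate in the cache key.
map_task = cached_remote_fn(_map_task, num_returns="streaming", num_gpus=1)

# Dynamic, per-task args are supplied via `.options(...)` at submission time.
gen = map_task.options(
    name="Map(<fn>)",
    _generator_backpressure_num_objects=4,  # 2 * max blocks in the gen buffer
).remote([1, 2, 3])
```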
@@ -437,8 +437,11 @@ def _are_remote_args_compatible(prev_args, next_args):
next_args = _canonicalize(next_args)
remote_args = next_args.copy()
for key in INHERITABLE_REMOTE_ARGS:
if key in prev_args:
# NOTE: We only carry over an inheritable value if it isn't
# already provided in the remote args
if key in prev_args and key not in remote_args:
Contributor (Author) commented:
This was incorrectly handling `scheduling_strategy`: the upstream value could overwrite one explicitly set on the downstream operator, making genuinely conflicting remote args look compatible (the misspelled `scheduing_strategy` in the test below had been masking this).

remote_args[key] = prev_args[key]

if prev_args != remote_args:
return False
return True
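A minimal sketch of the fixed check (simplified: the real implementation canonicalizes args first and inherits more keys than shown here): an upstream value is carried over only when the downstream operator doesn't set the key itself, so a genuine `SPREAD` vs. `PACK` conflict is detected instead of being overwritten.

```python
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]  # simplified for the sketch

def _are_remote_args_compatible(prev_args, next_args):
    remote_args = next_args.copy()
    for key in INHERITABLE_REMOTE_ARGS:
        # Inherit the upstream value only if the downstream op didn't set it.
        if key in prev_args and key not in remote_args:
            remote_args[key] = prev_args[key]
    return prev_args == remote_args

# Conflicting strategies are no longer masked by the upstream value
# overwriting the downstream one:
assert not _are_remote_args_compatible(
    {"scheduling_strategy": "SPREAD"}, {"scheduling_strategy": "PACK"}
)
# Absent a downstream value, the upstream strategy is inherited and compatible:
assert _are_remote_args_compatible({"scheduling_strategy": "SPREAD"}, {})
```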
35 changes: 30 additions & 5 deletions python/ray/data/_internal/remote_fn.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, Hashable, List

import ray

@@ -12,11 +12,22 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
(ray imports ray.data in order to allow ``ray.data.read_foo()`` to work,
which means ray.remote cannot be used top-level in ray.data).

Note: Dynamic arguments should not be passed in directly,
NOTE: Dynamic arguments should not be passed in directly,
and should be set with ``options`` instead:
``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
"""
if fn not in CACHED_FUNCTIONS:

# NOTE: Hashing the passed-in arguments guarantees that we cache a
# complete instantiation of the Ray remote function.
#
# To compute a deterministic hash of the passed-in arguments:
# - Sort all KV-pairs by their keys
# - Convert the sorted list into a tuple
# - Compute the hash of the resulting tuple
hashable_args = _make_hashable(ray_remote_args)
args_hash = hash(hashable_args)

if (fn, args_hash) not in CACHED_FUNCTIONS:
default_ray_remote_args = {
# Use the default scheduling strategy for all tasks so that we will
# not inherit a placement group from the caller, if there is one.
@@ -27,8 +38,22 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
}
ray_remote_args = {**default_ray_remote_args, **ray_remote_args}
_add_system_error_to_retry_exceptions(ray_remote_args)
CACHED_FUNCTIONS[fn] = ray.remote(**ray_remote_args)(fn)
return CACHED_FUNCTIONS[fn]

CACHED_FUNCTIONS[(fn, args_hash)] = ray.remote(**ray_remote_args)(fn)

return CACHED_FUNCTIONS[(fn, args_hash)]


def _make_hashable(obj):
if isinstance(obj, (List, tuple)):
return tuple([_make_hashable(o) for o in obj])
elif isinstance(obj, Dict):
converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()]
return tuple(sorted(converted, key=lambda t: t[0]))
elif isinstance(obj, Hashable):
return obj
else:
raise ValueError(f"Type {type(obj)} is not hashable")


def _add_system_error_to_retry_exceptions(ray_remote_args) -> None:
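To illustrate the new cache keying (a sketch using the helpers from this diff; the argument values are arbitrary): `_make_hashable` sorts KV-pairs by key, so dict insertion order is irrelevant, while different static args map to distinct cached remote functions.

```python
from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn

# Key order is irrelevant after canonicalization...
assert _make_hashable({"num_cpus": 1, "memory": 100}) == _make_hashable(
    {"memory": 100, "num_cpus": 1}
)

def foo():
    pass

# ...so identical static args resolve to the same cached remote function,
assert cached_remote_fn(foo, num_cpus=1) is cached_remote_fn(foo, num_cpus=1)

# while different static args produce a distinct cache entry.
assert cached_remote_fn(foo, num_gpus=1) is not cached_remote_fn(foo, num_cpus=1)
```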
6 changes: 3 additions & 3 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -661,7 +661,7 @@ def test_read_map_batches_operator_fusion_incompatible_remote_args(
ray_start_regular_shared,
):
# Test that map operators won't get fused if the remote args are incompatible.
incompatiple_remote_args_pairs = [
incompatible_remote_args_pairs = [
# Use different resources.
({"num_cpus": 2}, {"num_gpus": 2}),
# Same resource, but different values.
@@ -670,9 +670,9 @@
({"resources": {"custom": 2}}, {"resources": {"custom": 1}}),
({"resources": {"custom1": 1}}, {"resources": {"custom2": 1}}),
# Different scheduling strategies.
({"scheduling_strategy": "SPREAD"}, {"scheduing_strategy": "PACK"}),
({"scheduling_strategy": "SPREAD"}, {"scheduling_strategy": "PACK"}),
]
for up_remote_args, down_remote_args in incompatiple_remote_args_pairs:
for up_remote_args, down_remote_args in incompatible_remote_args_pairs:
planner = Planner()
read_op = get_parquet_read_logical_op(
ray_remote_args={"resources": {"non-existent": 1}}
19 changes: 19 additions & 0 deletions python/ray/data/tests/test_map.py
@@ -256,6 +256,25 @@ def __call__(self, x):
ds.map_batches(Mapper, concurrency=1).materialize()


def test_gpu_workers_not_reused(shutdown_only):
"""By default, in Ray Core if `num_gpus` is specified workers will not be reused
for tasks invocation.

For more context check out https://github.com/ray-project/ray/issues/29624"""

ray.init(num_gpus=1)

total_blocks = 5
ds = ray.data.range(5, override_num_blocks=total_blocks)

def _get_worker_id(_):
return {"worker_id": ray.get_runtime_context().get_worker_id()}

unique_worker_ids = ds.map(_get_worker_id, num_gpus=1).unique("worker_id")

assert len(unique_worker_ids) == total_blocks


def test_concurrency(shutdown_only):
ray.init(num_cpus=6)
ds = ray.data.range(10, override_num_blocks=10)
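For context, the Ray Core behavior this test relies on can be sketched directly (assumed behavior per the issue linked above; requires a node with a GPU): each invocation of a `num_gpus` task is expected to start on a fresh worker, which is why `num_gpus` must actually reach the underlying `ray.remote` call.

```python
import ray

ray.init(num_gpus=1)

@ray.remote(num_gpus=1)
def worker_id() -> str:
    # Workers that held GPU resources are torn down after the task finishes.
    return ray.get_runtime_context().get_worker_id()

ids = ray.get([worker_id.remote() for _ in range(3)])
assert len(set(ids)) == 3  # no worker reuse across GPU task invocations
```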
55 changes: 55 additions & 0 deletions python/ray/data/tests/test_util.py
@@ -2,6 +2,7 @@

import numpy as np
import pytest
from typing_extensions import Hashable

import ray
from ray.data._internal.datasource.parquet_datasource import ParquetDatasource
@@ -11,6 +12,7 @@
trace_allocation,
trace_deallocation,
)
from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn
from ray.data._internal.util import (
_check_pyarrow_version,
_split_list,
@@ -19,6 +21,59 @@
from ray.data.tests.conftest import * # noqa: F401, F403


def test_cached_remote_fn():
def foo():
pass

cpu_only_foo = cached_remote_fn(foo, num_cpus=1)
cached_cpu_only_foo = cached_remote_fn(foo, num_cpus=1)

assert cpu_only_foo == cached_cpu_only_foo

gpu_only_foo = cached_remote_fn(foo, num_gpus=1)

assert cpu_only_foo != gpu_only_foo


def test_make_hashable():
valid_args = {
"int": 0,
"float": 1.2,
"str": "foo",
"dict": {
0: 0,
1.2: 1.2,
},
"list": list(range(10)),
"tuple": tuple(range(3)),
"type": Hashable,
}

hashable_args = _make_hashable(valid_args)

assert hash(hashable_args) == hash(
(
("dict", ((0, 0), (1.2, 1.2))),
("float", 1.2),
("int", 0),
("list", (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)),
("str", "foo"),
("tuple", (0, 1, 2)),
("type", Hashable),
)
)

# Invalid case #1: key types can't be mixed (sorting raises a TypeError)
invalid_args = {0: 1, "bar": "baz"}

with pytest.raises(TypeError) as exc_info:
_make_hashable(invalid_args)

assert (
str(exc_info.value) == "'<' not supported between instances of 'str' and 'int'"
)


def test_check_pyarrow_version_bounds(unsupported_pyarrow_version):
# Test that pyarrow versions outside of the defined bounds cause an ImportError to
# be raised.