[Datasets] Run _get_read_tasks with NodeAffinitySchedulingStrategy (ray-project#33212)

This PR changes `_get_read_tasks` so that it always runs as a Ray task on the same node as the caller. For Ray client, the task will be executed on the head node. The motivation for this PR is to simplify the logic so that we don't run `_get_read_tasks` on an arbitrary node.

Co-authored-by: Matthew Deng <matt@anyscale.com>
Signed-off-by: elliottower <elliot@elliottower.com>
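As a minimal illustration of the scheduling pattern this commit adopts (a sketch, not part of the commit itself): a task is pinned to the caller's node with a hard NodeAffinitySchedulingStrategy, and the placement is verified by reading the node ID inside the task.

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()


@ray.remote(num_cpus=0)
def where_am_i() -> str:
    # Returns the ID of the node this task actually ran on.
    return ray.get_runtime_context().get_node_id()


driver_node_id = ray.get_runtime_context().get_node_id()
strategy = NodeAffinitySchedulingStrategy(
    node_id=driver_node_id,
    soft=False,  # hard constraint: do not fall back to another node
)

# The task is forced onto the same node as the caller.
assert ray.get(where_am_i.options(scheduling_strategy=strategy).remote()) == driver_node_id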
2 people authored and elliottower committed Apr 22, 2023
1 parent 9534a99 commit 0174b17
Showing 3 changed files with 47 additions and 4 deletions.
4 changes: 4 additions & 0 deletions python/ray/data/_internal/remote_fn.py
@@ -13,6 +13,10 @@ def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
     This is used in Datasets to avoid circular import issues with ray.remote.
     (ray imports ray.data in order to allow ``ray.data.read_foo()`` to work,
     which means ray.remote cannot be used top-level in ray.data).
+
+    Note: Dynamic arguments should not be passed in directly,
+    and should be set with ``options`` instead:
+    ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
     """
     if fn not in CACHED_FUNCTIONS:
         ctx = DatasetContext.get_current()
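A rough usage sketch of the contract in the new docstring note (my_prepare_read is a hypothetical stand-in, not part of this commit): static remote args are passed to cached_remote_fn once, while per-call settings such as a scheduling strategy go through .options() so they are not baked into the cached function.

import ray
from ray.data._internal.remote_fn import cached_remote_fn  # internal Datasets helper
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


def my_prepare_read(path: str) -> str:
    # Hypothetical stand-in for a metadata-resolution step.
    return f"prepared: {path}"


ray.init()

# Static arguments never change between calls, so they are cached with the function.
prepare = cached_remote_fn(my_prepare_read, retry_exceptions=False, num_cpus=0)

# Dynamic, per-call settings are supplied through .options() instead.
strategy = NodeAffinitySchedulingStrategy(ray.get_runtime_context().get_node_id(), soft=False)
result = ray.get(prepare.options(scheduling_strategy=strategy).remote("s3://bucket/key"))
print(result)  # "prepared: s3://bucket/key"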
11 changes: 8 additions & 3 deletions python/ray/data/read_api.py
@@ -304,11 +304,16 @@ def read_datasource(
             datasource, ctx, cur_pg, parallelism, local_uri, read_args
         )
     else:
-        # Prepare read in a remote task so that in Ray client mode, we aren't
-        # attempting metadata resolution from the client machine.
+        # Prepare read in a remote task on the same node.
+        # NOTE: in Ray client mode, this is expected to run on the head node,
+        # so we aren't attempting metadata resolution from the client machine.
+        scheduling_strategy = NodeAffinitySchedulingStrategy(
+            ray.get_runtime_context().get_node_id(),
+            soft=False,
+        )
         get_read_tasks = cached_remote_fn(
             _get_read_tasks, retry_exceptions=False, num_cpus=0
-        )
+        ).options(scheduling_strategy=scheduling_strategy)

         requested_parallelism, min_safe_parallelism, read_tasks = ray.get(
             get_read_tasks.remote(
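Because the strategy is supplied per call through .options() rather than baked into the cached remote function, the same function can be pinned to a different node on each call. A small sketch of that idea (independent of this commit's code):

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()


@ray.remote(num_cpus=0)
def report_node() -> str:
    return ray.get_runtime_context().get_node_id()


# Pin the same remote function to each alive node in turn and confirm placement.
for node in ray.nodes():
    if not node["Alive"]:
        continue
    strategy = NodeAffinitySchedulingStrategy(node_id=node["NodeID"], soft=False)
    ran_on = ray.get(report_node.options(scheduling_strategy=strategy).remote())
    assert ran_on == node["NodeID"]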
36 changes: 35 additions & 1 deletion python/ray/data/tests/test_dataset_formats.py
@@ -12,6 +12,7 @@
 from fsspec.implementations.http import HTTPFileSystem

 import ray
+from ray._private.test_utils import wait_for_condition
 from ray.data._internal.arrow_block import ArrowRow
 from ray.data._internal.execution.interfaces import TaskContext
 from ray.data.block import Block, BlockAccessor
@@ -305,7 +306,7 @@ def get_node_id():
     assert node_ids == {bar_node_id}


-def test_read_s3_file_error(ray_start_regular_shared, s3_path):
+def test_read_s3_file_error(shutdown_only, s3_path):
     dummy_path = s3_path + "_dummy"
     error_message = "Please check that file exists and has properly configured access."
     with pytest.raises(OSError, match=error_message):
@@ -324,6 +325,39 @@ def test_read_s3_file_error(ray_start_regular_shared, s3_path):
         _handle_read_os_error(error, dummy_path)


+# NOTE: All tests above share a Ray cluster, while the tests below do not. These
+# tests should be reordered only with care, to retain this invariant!
+
+
+def test_get_read_tasks(ray_start_cluster):
+    ray.shutdown()
+    cluster = ray_start_cluster
+    cluster.add_node(num_cpus=4)
+    cluster.add_node(num_cpus=4)
+    cluster.wait_for_nodes()
+    ray.init(cluster.address)
+
+    head_node_id = ray.get_runtime_context().get_node_id()
+
+    # Issue a read so that `_get_read_tasks` is executed.
+    ray.data.range(10).fully_executed()
+
+    # Verify that `_get_read_tasks` was executed on the same node (the head node).
+    def verify_get_read_tasks():
+        from ray.experimental.state.api import list_tasks
+
+        task_states = list_tasks(
+            address=cluster.address, filters=[("name", "=", "_get_read_tasks")]
+        )
+        # Verify that exactly one such task ran, and that it ran on the head node.
+        assert len(task_states) == 1
+        assert task_states[0]["name"] == "_get_read_tasks"
+        assert task_states[0]["node_id"] == head_node_id
+        return True
+
+    wait_for_condition(verify_get_read_tasks, timeout=20)
+
+
 if __name__ == "__main__":
     import sys
