Skip to content

Commit

Permalink
[Core][aDAG] support multi readers in multi node when dag is created …
Browse files Browse the repository at this point in the history
…from an actor (ray-project#47601)

Currently, when a DAG is created from an actor, we use a different mechanism than when it is created from a driver: in a driver we create a ProxyActor, whereas in an actor we just use the actor itself.

This inconsistency is error-prone. For example, when supporting multiple readers across multiple nodes, I found a deadlock: the driver actor needs to call ray.get(a.allocate_channel.remote()) on a downstream actor while that downstream actor calls ray.get(driver_actor.create_ref.remote()).

This commit fixes the issue by making ProxyActor the default mechanism, even when a DAG is created inside an actor.

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
rkooo567 authored and ujjawal-khare committed Oct 15, 2024
1 parent 89b0fc7 commit 8779e4a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 9 deletions.
6 changes: 3 additions & 3 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,9 +1277,9 @@ def _get_or_compile(
input_task = self.idx_to_task[self.input_task_idx]
# Register custom serializers for inputs provided to dag.execute().
input_task.dag_node.type_hint.register_custom_serializer()
self.dag_input_channels = input_task.output_channels
assert self.dag_input_channels is not None

assert len(input_task.output_channels) == 1
self.dag_input_channel = input_task.output_channels[0]
assert self.dag_input_channel is not None
# Create executable tasks for each actor
for actor_handle, tasks in self.actor_to_tasks.items():
# Dict from non-dag-input arg to the set of tasks that consume it.
Expand Down
44 changes: 44 additions & 0 deletions python/ray/dag/tests/experimental/test_multi_node_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,50 @@ def _get_node_id(self) -> "ray.NodeID":
compiled_dag.teardown()


def test_multi_node_dag_from_actor(ray_start_cluster):
    """Compile and execute a DAG from inside an actor on a two-node cluster.

    The DAG input is consumed by two readers: ``SameNodeActor`` (pinned to the
    node of the actor that builds the DAG) and ``RemoteNodeActor``. The test
    passes when the prompt string round-trips through the compiled DAG
    unchanged.

    Args:
        ray_start_cluster: Fixture yielding a cluster to which nodes are added.
    """
    cluster = ray_start_cluster
    # Two single-CPU nodes; ray.init() is called after the first node is added
    # and before the second.
    cluster.add_node(num_cpus=1)
    ray.init()
    cluster.add_node(num_cpus=1)

    @ray.remote(num_cpus=0)
    class SameNodeActor:
        """Echoes its input; requests no CPU so it can share a node."""

        def predict(self, x: str):
            return x

    @ray.remote(num_cpus=1)
    class RemoteNodeActor:
        """Takes the raw input and the upstream result; returns the latter."""

        def predict(self, x: str, y: str):
            return y

    @ray.remote(num_cpus=1)
    class DriverActor:
        """Actor that builds, compiles and executes the DAG itself."""

        def __init__(self):
            # Hard-pin the base actor to this actor's own node.
            local_node_id = ray.get_runtime_context().get_node_id()
            self._base_actor = SameNodeActor.options(
                scheduling_strategy=NodeAffinitySchedulingStrategy(
                    local_node_id, soft=False
                )
            ).remote()
            # Requests one CPU, as does DriverActor, and each node was added
            # with a single CPU -- so this actor is expected to land on the
            # other node.
            self._refiner_actor = RemoteNodeActor.remote()

            # `inp` fans out to both actors, i.e. the DAG input has two
            # readers.
            with InputNode() as inp:
                x = self._base_actor.predict.bind(inp)
                dag = self._refiner_actor.predict.bind(inp, x)

            self._adag = dag.experimental_compile(
                _execution_timeout=120,
            )

        def call(self, prompt: str) -> str:
            """Run the compiled DAG on ``prompt`` and return its output.

            Both actors pass the value through unchanged, so the result is the
            prompt itself (a ``str``).
            """
            return ray.get(self._adag.execute(prompt))

    parallel = DriverActor.remote()
    assert ray.get(parallel.call.remote("abc")) == "abc"


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
9 changes: 8 additions & 1 deletion python/ray/experimental/channel/shared_memory_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def create_channel(
cpu_data_typ=cpu_data_typ,
)

return CompositeChannel(writer, reader_and_node_list, self._num_shm_buffers)
return CompositeChannel(
writer,
reader_and_node_list,
self._num_shm_buffers,
read_by_adag_driver,
)

def set_nccl_group_id(self, group_id: str) -> None:
assert self.requires_nccl()
Expand Down Expand Up @@ -652,6 +657,7 @@ def __init__(
writer: Optional[ray.actor.ActorHandle],
reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
num_shm_buffers: int,
read_by_adag_driver: bool,
_channel_dict: Optional[Dict[ray.ActorID, ChannelInterface]] = None,
_channels: Optional[Set[ChannelInterface]] = None,
_writer_registered: bool = False,
Expand Down Expand Up @@ -738,6 +744,7 @@ def __reduce__(self):
self._writer,
self._reader_and_node_list,
self._num_shm_buffers,
self._read_by_adag_driver,
self._channel_dict,
self._channels,
self._writer_registered,
Expand Down
18 changes: 13 additions & 5 deletions python/ray/tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,8 +906,12 @@ def __init__(self):
def pass_channel(self, channel):
self._chan = channel

def create_composite_channel(self, writer, reader_and_node_list):
self._chan = ray_channel.CompositeChannel(writer, reader_and_node_list, 10)
def create_composite_channel(
self, writer, reader_and_node_list, read_by_adag_driver
):
self._chan = ray_channel.CompositeChannel(
writer, reader_and_node_list, 10, read_by_adag_driver
)
return self._chan

def read(self):
Expand All @@ -922,7 +926,9 @@ def write(self, value):
node2 = get_actor_node_id(actor2)

# Create a channel to communicate between driver process and actor1.
driver_to_actor1_channel = ray_channel.CompositeChannel(None, [(actor1, node1)], 10)
driver_to_actor1_channel = ray_channel.CompositeChannel(
None, [(actor1, node1)], 10, False
)
ray.get(actor1.pass_channel.remote(driver_to_actor1_channel))
driver_to_actor1_channel.write("hello")
assert ray.get(actor1.read.remote()) == "hello"
Expand Down Expand Up @@ -979,7 +985,9 @@ def pass_channel(self, channel):
self._chan = channel

def create_composite_channel(self, writer, reader_and_node_list):
self._chan = ray_channel.CompositeChannel(writer, reader_and_node_list, 10)
self._chan = ray_channel.CompositeChannel(
writer, reader_and_node_list, 10, False
)
return self._chan

def read(self):
Expand All @@ -995,7 +1003,7 @@ def write(self, value):

# The driver writes data to CompositeChannel and actor1 and actor2 read it.
driver_output_channel = ray_channel.CompositeChannel(
None, [(actor1, node1), (actor2, node2)], 10
None, [(actor1, node1), (actor2, node2)], 10, False
)
ray.get(actor1.pass_channel.remote(driver_output_channel))
ray.get(actor2.pass_channel.remote(driver_output_channel))
Expand Down

0 comments on commit 8779e4a

Please sign in to comment.