[aDAG] Allow custom NCCL group for aDAG #47141

Merged · 8 commits · Sep 6, 2024
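This PR lets an accelerated DAG (aDAG) reuse a caller-supplied NCCL communicator instead of always initializing its own: the communicator is passed through the TorchTensor type hint as `transport=nccl_group` (a `GPUCommunicator` instance) rather than the string `transport="nccl"`. As orientation, here is a minimal sketch of what such a custom communicator could look like. The class name, constructor, and elided bodies are illustrative assumptions; the method names follow the `GPUCommunicator` interface imported in the diff below, so treat the exact signatures as approximate.

from typing import List, Optional, Tuple

import ray
from ray.experimental.channel.gpu_communicator import GPUCommunicator


class MyNcclGroup(GPUCommunicator):
    """Sketch: adapt an already-initialized NCCL communicator for aDAG."""

    def __init__(self, world_size: int, actor_handles: List["ray.actor.ActorHandle"]):
        self._world_size = world_size
        self._actor_handles = actor_handles

    def initialize(self, rank: int) -> None:
        # Called on each participating actor with its assigned rank;
        # create or look up the underlying NCCL communicator here.
        ...

    def get_rank(self, actor: "ray.actor.ActorHandle") -> int:
        # Map an actor handle to its NCCL rank.
        return self._actor_handles.index(actor)

    def get_self_rank(self) -> Optional[int]:
        # Rank of the current worker, or None on the driver.
        ...

    def get_world_size(self) -> int:
        return self._world_size

    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        # e.g., an NCCL send on the wrapped communicator.
        ...

    def recv(self, shape: Tuple[int], dtype: "torch.dtype", peer_rank: int) -> "torch.Tensor":
        # e.g., allocate a tensor of the given shape/dtype and NCCL-recv into it.
        ...

    def destroy(self) -> None:
        ...

The group would then be attached to a DAG edge with something like `dag.with_type_hint(TorchTensorType(transport=my_group))`; `TorchTensorType` and `with_type_hint` are assumptions inferred from the TorchTensor hints named in the error messages below.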
35 changes: 34 additions & 1 deletion python/ray/dag/compiled_dag_node.py
@@ -9,6 +9,7 @@
from typing import NamedTuple

from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
from ray.exceptions import RayTaskError, RayChannelError
from ray.util.annotations import PublicAPI
@@ -640,6 +641,11 @@ def __init__(
# Type hints specified by the user for DAG (intermediate) outputs.
self._type_hints = []

# This is set to True when a type hint of `transport="nccl"` is used.
self._use_default_nccl_group = False
# This is set to the custom NCCL group specified via a type hint of
# `transport=nccl_group`, if any.
self._custom_nccl_group: Optional[GPUCommunicator] = None
# Uniquely identifies the NCCL communicator that will be used within
# this DAG, if any.
self._nccl_group_id: Optional[str] = None
@@ -806,6 +812,33 @@ def _preprocess(self) -> None:
    if dag_node.type_hint.requires_nccl():
        # Add all writers to the NCCL group.
        nccl_actors.add(actor_handle)
        custom_nccl_group = dag_node.type_hint.get_custom_nccl_group()
        mixed_nccl_group_error_message = (
            "Accelerated DAGs do not support mixed usage of "
            "type hints of the default NCCL group "
            '(i.e., TorchTensor(transport="nccl")) '
            "and a custom NCCL group "
            "(i.e., TorchTensor(transport=nccl_group)). "
            "Please check all the TorchTensor type hints and "
            "make sure only one type of NCCL transport is specified."
        )
        if custom_nccl_group is None:
            if self._custom_nccl_group is not None:
                raise ValueError(mixed_nccl_group_error_message)
            self._use_default_nccl_group = True
        else:
            if self._use_default_nccl_group:
                raise ValueError(mixed_nccl_group_error_message)
            if self._custom_nccl_group is not None:
                if self._custom_nccl_group != custom_nccl_group:
                    raise ValueError(
                        "Accelerated DAGs currently only support "
                        "a single custom NCCL group, but multiple "
                        "have been specified. Check all the "
                        "TorchTensor(transport=nccl_group) type hints "
                        "to make sure only one NCCL group is used."
                    )
            self._custom_nccl_group = custom_nccl_group
elif isinstance(dag_node, InputNode):
    if dag_node.type_hint.requires_nccl():
        raise ValueError(
@@ -916,7 +949,7 @@ def _preprocess(self) -> None:
if None in nccl_actors:
    raise ValueError("Driver cannot participate in the NCCL group.")
if nccl_actors and self._nccl_group_id is None:
    # Before this PR: self._nccl_group_id = _init_nccl_group(nccl_actors)
    self._nccl_group_id = _init_nccl_group(nccl_actors, self._custom_nccl_group)

if direct_input:
    self._input_num_positional_args = 1
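To make the validation above concrete, here is a hedged sketch of a DAG that the new `_preprocess` checks would reject, because one edge uses the default NCCL transport while another uses a custom group. The actors (`sender`, `relay`, `receiver`) and `my_group` are assumed (see the communicator sketch above), as is the `TorchTensorType` / `with_type_hint` / `experimental_compile` surface from aDAG examples of this period.

from ray.dag import InputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

with InputNode() as inp:
    x = sender.send.bind(inp)
    # First edge: default NCCL group.
    x = x.with_type_hint(TorchTensorType(transport="nccl"))
    y = relay.forward.bind(x)
    # Second edge: custom NCCL group; mixing both transports in one DAG
    # trips the mixed-usage check above.
    y = y.with_type_hint(TorchTensorType(transport=my_group))
    dag = receiver.recv.bind(y)

compiled = dag.experimental_compile()  # raises ValueError (mixed NCCL transports)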