From e9c6563ed4d29395753099a51d920ecde3901278 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Tue, 10 Sep 2024 09:27:42 -0700 Subject: [PATCH 01/16] WIP Signed-off-by: Rui Qiao --- .buildkite/core.rayci.yml | 4 +- python/ray/dag/__init__.py | 2 + python/ray/dag/compiled_dag_node.py | 218 ++++++++++++--- python/ray/dag/constants.py | 3 + python/ray/dag/context.py | 11 + python/ray/dag/dag_node.py | 7 + python/ray/dag/dag_node_operation.py | 251 +++++++++++++++-- python/ray/dag/dag_operation_future.py | 86 ++++++ .../experimental/test_execution_schedule.py | 16 +- .../test_execution_schedule_gpu.py | 75 +++++ .../experimental/test_torch_tensor_dag.py | 261 +++++++++++++++++- python/ray/experimental/channel/common.py | 1 + .../experimental/channel/gpu_communicator.py | 18 ++ python/ray/experimental/channel/nccl_group.py | 81 +++++- .../channel/torch_tensor_nccl_channel.py | 7 + 15 files changed, 958 insertions(+), 83 deletions(-) create mode 100644 python/ray/dag/dag_operation_future.py diff --git a/.buildkite/core.rayci.yml b/.buildkite/core.rayci.yml index dcbebd735c3e..09005593c319 100644 --- a/.buildkite/core.rayci.yml +++ b/.buildkite/core.rayci.yml @@ -376,10 +376,8 @@ steps: - gpu instance_type: gpu-large commands: - # This machine has 4 GPUs, and we need 2 GPUs, so allow 2 tests to run in - # parallel. - bazel run //ci/ray_ci:test_in_docker -- //python/ray/tests/... //python/ray/dag/... core - --parallelism-per-worker 2 --gpus 2 + --gpus 4 --build-name coregpubuild --only-tags multi_gpu depends_on: coregpubuild diff --git a/python/ray/dag/__init__.py b/python/ray/dag/__init__.py index eb13abc5a53e..bc081be76e50 100644 --- a/python/ray/dag/__init__.py +++ b/python/ray/dag/__init__.py @@ -11,6 +11,7 @@ DAGInputData, ) from ray.dag.output_node import MultiOutputNode +from ray.dag.dag_operation_future import DAGOperationFuture from ray.dag.constants import ( PARENT_CLASS_NODE_KEY, PREV_CLASS_METHOD_CALL_KEY, @@ -27,6 +28,7 @@ "ClassMethodNode", "CollectiveOutputNode", "DAGNode", + "DAGOperationFuture", "FunctionNode", "InputNode", "InputAttributeNode", diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 2b4bfa68f30a..4b98633103f8 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -1,6 +1,7 @@ import weakref import asyncio from collections import defaultdict +from contextlib import nullcontext from dataclasses import dataclass, asdict from typing import Any, Dict, FrozenSet, List, Tuple, Union, Optional, Set import logging @@ -10,6 +11,7 @@ import traceback import ray.exceptions +from ray.dag.dag_operation_future import GPUFuture, DAGOperationFuture, ResolvedFuture from ray.experimental.channel.cached_channel import CachedChannel from ray.experimental.channel.gpu_communicator import GPUCommunicator import ray @@ -20,6 +22,7 @@ _process_return_vals, ) from ray.experimental.channel import ( + ChannelContext, ChannelInterface, ChannelOutputType, ReaderInterface, @@ -34,6 +37,7 @@ from ray.experimental.channel.shared_memory_channel import ( SharedMemoryType, + TorchTensorType, ) from ray.experimental.channel.torch_tensor_nccl_channel import ( @@ -46,7 +50,10 @@ _DAGNodeOperationType, _DAGOperationGraphNode, _build_dag_node_operation_graph, + _extract_execution_schedule, _generate_actor_to_execution_schedule, + _generate_overlapped_execution_schedule, + _visualize_execution_schedule, ) from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -114,6 +121,7 @@ def do_exec_tasks( self, tasks: 
List["ExecutableTask"], schedule: List[_DAGNodeOperation], + overlap_gpu_communication: bool = False, ) -> None: """A generic actor method to begin executing the operations belonging to an actor. This runs an infinite loop to execute each _DAGNodeOperation in the @@ -123,6 +131,8 @@ def do_exec_tasks( Args: tasks: the executable tasks corresponding to the actor methods. schedule: A list of _DAGNodeOperation that should be executed in order. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. """ try: for task in tasks: @@ -134,7 +144,7 @@ def do_exec_tasks( break for operation in schedule: done = tasks[operation.exec_task_idx].exec_operation( - self, operation.type + self, operation.type, overlap_gpu_communication ) if done: break @@ -148,12 +158,15 @@ def do_profile_tasks( self, tasks: List["ExecutableTask"], schedule: List[_DAGNodeOperation], + overlap_gpu_communication: bool = False, ) -> None: """A generic actor method similar to `do_exec_tasks`, but with profiling enabled. Args: tasks: the executable tasks corresponding to the actor methods. schedule: A list of _DAGNodeOperation that should be executed in order. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. """ try: for task in tasks: @@ -170,7 +183,7 @@ def do_profile_tasks( start_t = time.perf_counter() task = tasks[operation.exec_task_idx] done = tasks[operation.exec_task_idx].exec_operation( - self, operation.type + self, operation.type, overlap_gpu_communication ) end_t = time.perf_counter() @@ -213,6 +226,27 @@ def _wrap_exception(exc): return wrapped +def _get_nccl_group_id(type_hint: ChannelOutputType) -> Optional[str]: + """ + Get the NCCL group ID from the type hint. If the type hint does not + require NCCL, return None. + + Args: + type_hint: The type hint of the channel. + + Returns: + The NCCL group ID if the type hint requires NCCL, otherwise None. + """ + if type_hint.requires_nccl(): + if isinstance(type_hint, SharedMemoryType): + assert type_hint._contains_type.requires_nccl() + return _get_nccl_group_id(type_hint._contains_type) + else: + assert isinstance(type_hint, TorchTensorType) + return type_hint.nccl_group_id + return None + + @DeveloperAPI class CompiledTask: """Wraps the normal Ray DAGNode with some metadata.""" @@ -378,10 +412,11 @@ def __init__( self.output_writer: WriterInterface = SynchronousWriter( self.output_channels, self.output_idxs ) - # Store the intermediate result of a READ or COMPUTE operation. + # The intermediate future for a READ or COMPUTE operation, + # and `wait()` must be called to get the actual result of the operation. # The result of a READ operation will be used by a COMPUTE operation, # and the result of a COMPUTE operation will be used by a WRITE operation. 
- self._intermediate_buffer: Any = None + self._intermediate_future: Optional[DAGOperationFuture] = None def cancel(self): """ @@ -403,46 +438,92 @@ def prepare(self): self.input_reader.start() self.output_writer.start() - def set_intermediate_buffer(self, data: Any): + import cupy as cp + + self._send_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() + self._recv_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() + if self.output_type_hint.requires_nccl(): + nccl_group_id = _get_nccl_group_id(self.output_type_hint) + nccl_group = ChannelContext.get_current().nccl_groups.get(nccl_group_id) + assert nccl_group is not None + self._send_stream = nccl_group.send_stream + if self.input_type_hints: + for type_hint in self.input_type_hints: + if type_hint.requires_nccl(): + nccl_group_id = _get_nccl_group_id(type_hint) + nccl_group = ChannelContext.get_current().nccl_groups.get(nccl_group_id) + assert nccl_group is not None + if not isinstance(self._recv_stream, nullcontext): + assert self._recv_stream == nccl_group.recv_stream, ( + "Currently all torch tensor input channels of a " + "Compiled Graph task should use the same recv cuda stream." + ) + self._recv_stream = nccl_group.recv_stream + + def wrap_and_set_intermediate_future( + self, val: Any, wrap_in_gpu_future: bool + ) -> None: """ - Store the intermediate result of a READ or COMPUTE operation. + Wrap the value in a `DAGOperationFuture` and store to the intermediate future. + The value corresponds to result of a READ or COMPUTE operation. + + If wrap_in_gpu_future is True, the value will be wrapped in a _GPUFuture, + Otherwise, the future will be a ResolvedFuture. Args: - data: The intermediate result of a READ or COMPUTE operation. + val: The value to wrap in a future. + wrap_in_gpu_future: Whether to wrap the value in a _GPUFuture. """ - assert self._intermediate_buffer is None - self._intermediate_buffer = data + assert self._intermediate_future is None - def reset_intermediate_buffer(self) -> Any: + if wrap_in_gpu_future: + future = GPUFuture(val) + else: + future = ResolvedFuture(val) + self._intermediate_future = future + + def reset_and_wait_intermediate_future(self) -> Any: """ - Retrieve the intermediate result of a READ or COMPUTE operation, - and reset the intermediate buffer to None. + Reset the intermediate future and wait for the result. Returns: - The intermediate result of a READ or COMPUTE operation. + The result of a READ or COMPUTE operation from the intermediate future. """ - data = self._intermediate_buffer - self._intermediate_buffer = None - return data + future = self._intermediate_future + self._intermediate_future = None + return future.wait() - def _read(self) -> bool: + def _read(self, overlap_gpu_communication: bool) -> bool: """ Read input data from upstream DAG nodes and cache the intermediate result. + Args: + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. + Returns: True if system error occurs and exit the loop; otherwise, False. """ - assert self._intermediate_buffer is None + assert self._intermediate_future is None exit = False try: input_data = self.input_reader.read() - self.set_intermediate_buffer(input_data) + # When overlap_gpu_communication is enabled, wrap the result in + # a GPU future so that this read operation (communication) can + # be overlapped with computation. 
+ self.wrap_and_set_intermediate_future( + input_data, wrap_in_gpu_future=overlap_gpu_communication + ) except RayChannelError: # Channel closed. Exit the loop. exit = True return exit - def _compute(self, class_handle) -> bool: + def _compute( + self, + overlap_gpu_communication: bool, + class_handle, + ) -> bool: """ Retrieve the intermediate result from the READ operation and perform the computation. Then, cache the new intermediate result. The caller must ensure @@ -450,14 +531,15 @@ def _compute(self, class_handle) -> bool: correct intermediate result. Args: + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. class_handle: An instance of the class to which the actor belongs. For example, the type of `class_handle` is if the actor belongs to the `class Worker` class. - Returns: True if system error occurs and exit the loop; otherwise, False. """ - input_data = self.reset_intermediate_buffer() + input_data = self.reset_and_wait_intermediate_future() try: _process_return_vals(input_data, return_single_output=False) except Exception as exc: @@ -465,7 +547,9 @@ def _compute(self, class_handle) -> bool: # Propagate it and skip the actual task. We don't need to wrap the # exception in a RayTaskError here because it has already been wrapped # by the previous task. - self.set_intermediate_buffer(exc) + self.wrap_and_set_intermediate_future( + exc, wrap_in_gpu_future=overlap_gpu_communication + ) return False resolved_inputs = [] @@ -482,7 +566,12 @@ def _compute(self, class_handle) -> bool: output_val = method(*resolved_inputs, **self.resolved_kwargs) except Exception as exc: output_val = _wrap_exception(exc) - self.set_intermediate_buffer(output_val) + + # When overlap_gpu_communication is enabled, wrap the result in a GPU future + # so that this compute operation can be overlapped with communication. + self.wrap_and_set_intermediate_future( + output_val, wrap_in_gpu_future=overlap_gpu_communication + ) return False def _write(self) -> bool: @@ -494,7 +583,7 @@ def _write(self) -> bool: Returns: True if system error occurs and exit the loop; otherwise, False. """ - output_val = self.reset_intermediate_buffer() + output_val = self.reset_and_wait_intermediate_future() exit = False try: self.output_writer.write(output_val) @@ -507,27 +596,30 @@ def exec_operation( self, class_handle, op_type: _DAGNodeOperationType, + overlap_gpu_communication: bool = False, ) -> bool: """ An ExecutableTask corresponds to a DAGNode. It consists of three operations: READ, COMPUTE, and WRITE, which should be executed in order to ensure that each operation can read the correct intermediate result. - Args: class_handle: The handle of the class to which the actor belongs. op_type: The type of the operation. Possible types are READ, COMPUTE, and WRITE. - + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance. Returns: True if the next operation should not be executed; otherwise, False. 
""" if op_type == _DAGNodeOperationType.READ: - return self._read() + with self._recv_stream: + return self._read(overlap_gpu_communication) elif op_type == _DAGNodeOperationType.COMPUTE: - return self._compute(class_handle) + return self._compute(overlap_gpu_communication, class_handle) elif op_type == _DAGNodeOperationType.WRITE: - return self._write() + with self._send_stream: + return self._write() @dataclass @@ -581,6 +673,7 @@ def __init__( asyncio_max_queue_size: Optional[int] = None, max_buffered_results: Optional[int] = None, max_inflight_executions: Optional[int] = None, + overlap_gpu_communication: Optional[bool] = None, ): """ Args: @@ -613,6 +706,11 @@ def __init__( are allowed to be sent to this DAG. Before submitting more requests, the caller is responsible for calling ray.get to get the result, otherwise, RayAdagCapacityExceeded is raised. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution. If True, the communication + and computation can be overlapped, which can improve the + performance of the DAG execution. If None, the default value + will be used. Returns: Channel: A wrapper around ray.ObjectRef. @@ -638,6 +736,9 @@ def __init__( self._buffer_size_bytes: Optional[int] = buffer_size_bytes if self._buffer_size_bytes is None: self._buffer_size_bytes = ctx.buffer_size_bytes + self._overlap_gpu_communication: Optional[bool] = overlap_gpu_communication + if self._overlap_gpu_communication is None: + self._overlap_gpu_communication = ctx.overlap_gpu_communication self._default_type_hint: ChannelOutputType = SharedMemoryType( buffer_size_bytes=self._buffer_size_bytes, @@ -995,7 +1096,7 @@ def _preprocess(self) -> None: "Expected P2P actor handles to be a subset of the custom NCCL group" ) self._nccl_group_id_p2p = _init_nccl_group( - nccl_actors_p2p, self._custom_nccl_group_p2p + nccl_actors_p2p, self._custom_nccl_group_p2p, self._overlap_gpu_communication ) custom_nccl_group_to_id[ self._custom_nccl_group_p2p @@ -1025,7 +1126,7 @@ def _preprocess(self) -> None: self._nccl_group_id_p2p = actors_to_nccl_group_id[actors] else: self._nccl_group_id_p2p = _init_nccl_group( - nccl_actors_p2p, self._custom_nccl_group_p2p + nccl_actors_p2p, self._custom_nccl_group_p2p, self._overlap_gpu_communication ) actors_to_nccl_group_id[actors] = self._nccl_group_id_p2p @@ -1434,6 +1535,7 @@ def _get_or_compile( exec_task_func, executable_tasks, self.actor_to_execution_schedule[actor_handle], + self._overlap_gpu_communication, ) assert self.output_task_idx is not None @@ -1522,26 +1624,38 @@ def _generate_dag_operation_graph_node( for exec_task_idx, exec_task in enumerate(executable_tasks): # Divide a DAG node into three _DAGOperationGraphNodes: READ, COMPUTE, # and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation. 
- task_idx = exec_task.task_idx - dag_node = self.idx_to_task[task_idx].dag_node + task_index = exec_task.task_idx + dag_node = self.idx_to_task[task_index].dag_node + method_name = exec_task.method_name actor_handle = dag_node._get_actor_handle() requires_nccl = dag_node.type_hint.requires_nccl() + upstream_requires_nccl = False + for upstream_node in dag_node._upstream_nodes: + if upstream_node.type_hint.requires_nccl(): + upstream_requires_nccl = True + break read_node = _DAGOperationGraphNode( - _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.READ), - task_idx, + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.READ, method_name + ), + task_index, actor_handle, - requires_nccl, + upstream_requires_nccl, ) compute_node = _DAGOperationGraphNode( - _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.COMPUTE), - task_idx, + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.COMPUTE, method_name + ), + task_index, actor_handle, - isinstance(dag_node, CollectiveOutputNode), + False, ) write_node = _DAGOperationGraphNode( - _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.WRITE), - task_idx, + _DAGNodeOperation( + exec_task_idx, _DAGNodeOperationType.WRITE, method_name + ), + task_index, actor_handle, requires_nccl, ) @@ -1599,7 +1713,23 @@ def _build_execution_schedule( ) # Step 2: Generate an execution schedule for each actor using topological sort actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) - return actor_to_execution_schedule + + # Step 3: Optimize the execution schedule if configured + if self._optimize_execution_schedule: + actor_to_overlapped_schedule = _generate_overlapped_execution_schedule( + actor_to_execution_schedule + ) + else: + actor_to_overlapped_schedule = None + + from ray.dag.constants import RAY_ADAG_VISUALIZE_SCHEDULE + + if RAY_ADAG_VISUALIZE_SCHEDULE: + _visualize_execution_schedule( + actor_to_execution_schedule, actor_to_overlapped_schedule, graph + ) + + return _extract_execution_schedule(actor_to_overlapped_schedule) def _detect_deadlock(self) -> bool: """ @@ -2205,6 +2335,7 @@ def build_compiled_dag_from_ray_dag( asyncio_max_queue_size: Optional[int] = None, max_buffered_results: Optional[int] = None, max_inflight_executions: Optional[int] = None, + overlap_gpu_communication: Optional[bool] = None, ) -> "CompiledDAG": compiled_dag = CompiledDAG( execution_timeout, @@ -2213,6 +2344,7 @@ def build_compiled_dag_from_ray_dag( asyncio_max_queue_size, max_buffered_results, max_inflight_executions, + overlap_gpu_communication, ) def _build_compiled_dag(node): diff --git a/python/ray/dag/constants.py b/python/ray/dag/constants.py index ed86adf914ed..bc268412a934 100644 --- a/python/ray/dag/constants.py +++ b/python/ray/dag/constants.py @@ -19,3 +19,6 @@ # Feature flag to turn on profiling. RAY_ADAG_ENABLE_PROFILING = os.environ.get("RAY_ADAG_ENABLE_PROFILING", "0") == "1" + +# Feature flag to turn on visualization of the execution schedule. 
+RAY_ADAG_VISUALIZE_SCHEDULE = os.environ.get("RAY_ADAG_VISUALIZE_SCHEDULE", "0") == "1" diff --git a/python/ray/dag/context.py b/python/ray/dag/context.py index 03eb8915d13b..29e1d5bf2c78 100644 --- a/python/ray/dag/context.py +++ b/python/ray/dag/context.py @@ -25,6 +25,10 @@ os.environ.get("RAY_DAG_max_inflight_executions", 10) ) +DEFAULT_OVERLAP_GPU_COMMUNICATION = bool( + os.environ.get("RAY_DAG_overlap_gpu_communication", 0) +) + @DeveloperAPI @dataclass @@ -58,6 +62,12 @@ class DAGContext: executions is beyond the DAG capacity, the new execution would be blocked in the first place; therefore, this limit is only enforced when it is smaller than the DAG capacity. + max_inflight_executions: The maximum number of in-flight executions + that can be submitted before consuming the output. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution. If True, the communication + and computation can be overlapped, which can improve the + performance of the DAG execution. """ execution_timeout: int = DEFAULT_EXECUTION_TIMEOUT_S @@ -66,6 +76,7 @@ class DAGContext: asyncio_max_queue_size: int = DEFAULT_ASYNCIO_MAX_QUEUE_SIZE max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS + overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION @staticmethod def get_current() -> "DAGContext": diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 320fe392bb4e..ac4802fd8244 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -158,6 +158,7 @@ def experimental_compile( _asyncio_max_queue_size: Optional[int] = None, _max_buffered_results: Optional[int] = None, _max_inflight_executions: Optional[int] = None, + _overlap_gpu_communication: Optional[bool] = None, ) -> "ray.dag.CompiledDAG": """Compile an accelerated execution path for this DAG. @@ -182,6 +183,11 @@ def experimental_compile( are allowed to be sent to this DAG. Before submitting more requests, the caller is responsible for calling ray.get to clear finished in-flight requests. + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution. If True, the communication + and computation can be overlapped, which can improve the + performance of the DAG execution. If None, the default value + will be used. Returns: A compiled DAG. 
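[Note] A minimal end-to-end sketch of the `_overlap_gpu_communication` flag documented in the `experimental_compile` hunk above, condensed from the tests added later in this series (test_execution_schedule_gpu.py and test_torch_tensor_dag.py). The inline `TorchTensorWorker` actor and the assumption of three visible GPUs are illustrative simplifications of the test fixtures, not part of this diff.

import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

@ray.remote(num_cpus=0, num_gpus=1)
class TorchTensorWorker:
    def send(self, shape, dtype, value: int):
        # Produce a tensor on this worker's GPU.
        return torch.ones(shape, dtype=dtype, device="cuda") * value

    def recv(self, tensor):
        # The tensor arrives on the receiver's GPU via NCCL.
        return (tensor[0].item(), tensor.shape, tensor.dtype)

sender1, sender2, receiver = (TorchTensorWorker.remote() for _ in range(3))
shape, dtype = (10000,), torch.float16

with InputNode() as inp:
    branches = [s.send.bind(shape, dtype, inp) for s in (sender1, sender2)]
    branches = [
        b.with_type_hint(
            TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
        )
        for b in branches
    ]
    branches = [receiver.recv.bind(b) for b in branches]
    dag = MultiOutputNode(branches)

# With overlap enabled, the receiver's schedule issues both NCCL reads before
# the first compute, so the second recv proceeds while the first result is
# being processed.
compiled_dag = dag.experimental_compile(_overlap_gpu_communication=True)
assert ray.get(compiled_dag.execute(1)) == [(1, shape, dtype), (1, shape, dtype)]
compiled_dag.teardown()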
@@ -215,6 +221,7 @@ def experimental_compile( _asyncio_max_queue_size, _max_buffered_results, _max_inflight_executions, + _overlap_gpu_communication, ) def execute( diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index 8f03547702be..d9cb3fe7c2e6 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -1,11 +1,17 @@ from functools import total_ordering from enum import Enum from typing import Set, Tuple, List, Dict, Optional +from typing import Optional, Tuple, List, Dict +import copy +import logging import ray import heapq from collections import defaultdict +logger = logging.getLogger(__name__) + + class _DAGNodeOperationType(Enum): """ There are three types of operations that a DAG node can perform: @@ -18,12 +24,22 @@ class _DAGNodeOperationType(Enum): COMPUTE = "COMPUTE" WRITE = "WRITE" + def __str__(self): + if self == _DAGNodeOperationType.READ: + return "R" + elif self == _DAGNodeOperationType.COMPUTE: + return "C" + elif self == _DAGNodeOperationType.WRITE: + return "W" + assert False, f"Unknown operation type: {self}" + class _DAGNodeOperation: def __init__( self, exec_task_idx: int, operation_type: _DAGNodeOperationType, + method_name: Optional[str] = None, ): """ Args: @@ -32,9 +48,12 @@ def __init__( as bind_index because there may be more tasks bound to an actor than tasks that appear in the current compiled DAG. operation_type: The type of operation to perform. + method_name: The name of the method that this operation originates + from. This is only for debugging purposes. """ self.exec_task_idx = exec_task_idx self.type = operation_type + self.method_name = method_name def __repr__(self): return ( @@ -43,6 +62,17 @@ def __repr__(self): f" type: {self.type})" ) + def __str__(self): + return f"([{self.exec_task_idx}] {self.method_name} {self.type})" + + def __hash__(self): + return hash((self.exec_task_idx, self.type)) + + def __eq__(self, other): + # An operation is uniquely identified by its `exec_task_idx` and type. + # `func_name` is only for debugging purposes. + return self.exec_task_idx == other.exec_task_idx and self.type == other.type + @total_ordering class _DAGOperationGraphNode: @@ -70,12 +100,14 @@ def __init__( self.task_idx = task_idx self.actor_handle = actor_handle self.requires_nccl = requires_nccl - # The in_edges and out_edges are sets of tuples. Each tuple contains - # an integer `task_idx`, which can be used to index into `idx_to_task` - # to get the corresponding task, and a `_DAGNodeOperationType`, which can - # be READ, COMPUTE, or WRITE. - self.in_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() - self.out_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() + # The in_edges and out_edges are dicts of tuples to strings. + # Each tuple (the key) contains an integer `task_idx`, which can be + # used to index into `idx_to_task` to get the corresponding task, + # and a `_DAGNodeOperationType`, which can be READ, COMPUTE, or WRITE. + # The string (the value) is the label of the edge, which will be used + # to annotate the edge in the visualization of the execution schedule. + self.in_edges: Dict[Tuple[int, _DAGNodeOperationType], str] = {} + self.out_edges: Dict[Tuple[int, _DAGNodeOperationType], str] = {} # The collective nodes are the nodes that belong to the same collective # operation. Each node is represented by a tuple of its task idx and type. 
self.collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set() @@ -179,14 +211,40 @@ def is_nccl_write(self) -> bool: def is_nccl_op(self) -> bool: return self.is_nccl_collective or self.is_nccl_write + def __str__(self): + class_name = ( + self.actor_handle._ray_actor_creation_function_descriptor.class_name + ) + actor_id = self._actor_id.hex() + actor_id_abbv = actor_id[:4] + "..." + return ( + class_name + + "_" + + actor_id_abbv + + f" [{self.operation.exec_task_idx}] " + + f"{self.operation.method_name} {self.operation.type}" + ) + + @property + def _actor_id(self): + return self.actor_handle._ray_actor_id.hex() -def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode): + +def _add_edge( + from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode, label: str = "" +): """ Add an edge from `from_node` to `to_node`. An edge is a tuple of the operation's `task_idx` and type. + + Args: + from_node: The node from which the edge originates. + to_node: The node to which the edge points. + label: The label of the edge. This will be used to annotate the edge + in the visualization of the execution schedule. """ - from_node.out_edges.add((to_node.task_idx, to_node.operation.type)) - to_node.in_edges.add((from_node.task_idx, from_node.operation.type)) + from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = label + to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = label def _push_candidate_node_if_ready( @@ -343,7 +401,7 @@ def _build_dag_node_operation_graph( # Add an edge from COMPUTE with `bind_index` i to COMPUTE with # `bind_index` i+1 if they belong to the same actor. if prev_compute_node is not None: - _add_edge(prev_compute_node, compute_node) + _add_edge(prev_compute_node, compute_node, "next") prev_compute_node = compute_node assert task_idx not in graph graph[task_idx] = { @@ -385,18 +443,102 @@ def _build_dag_node_operation_graph( _add_edge( graph[task_idx][_DAGNodeOperationType.WRITE], graph[downstream_task_idx][_DAGNodeOperationType.READ], + "nccl" + if graph[task_idx][_DAGNodeOperationType.WRITE].requires_nccl + else "shm", ) return graph +def _node_repr(node: _DAGOperationGraphNode, idx: int, optimized_index: int): + """ + Representation of a node in the visualization of the execution schedule. + + Args: + node: The node to be represented. + idx: The index of the node in the execution schedule. + optimized_index: The index of the node in the optimized execution schedule. + """ + return str(node) + f" {idx},{optimized_index}" + + +def _visualize_execution_schedule( + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ], + actor_to_overlapped_schedule: Optional[ + Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]] + ], + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], +): + """ + Visualize the execution schedule for each actor. + + Args: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operation nodes. + actor_to_overlapped_schedule: A dictionary that maps an actor handle to the + optimized execution schedule which is a list of operation nodes. + graph: A graph where each node is a _DAGOperationGraphNode. The key is + `task_idx`, the index to retrieve its task from `idx_to_task`, and + the value is a dictionary that maps the _DAGNodeOperationType (READ, + COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is + generated by `_build_dag_node_operation_graph`. 
+ """ + try: + import graphviz + except ImportError: + raise ImportError( + "Please install graphviz to visualize the execution schedule. " + "You can install it by running `pip install graphviz`." + ) + + dot = graphviz.Digraph(comment="DAG") + node_to_repr: Dict[_DAGOperationGraphNode, str] = {} + + # TODO: only visualize the execution schedule if the overlapped schedule is None. + if actor_to_overlapped_schedule is None: + actor_to_overlapped_schedule = actor_to_execution_schedule + for actor, execution_nodes in actor_to_execution_schedule.items(): + overlapped_schedule = actor_to_overlapped_schedule[actor] + node_to_optimized_index = { + node: i for i, node in enumerate(overlapped_schedule) + } + + with dot.subgraph(name=f"cluster_{execution_nodes[0]._actor_id}") as subgraph: + subgraph.attr(rank=execution_nodes[0]._actor_id) + for i, node in enumerate(execution_nodes): + optimized_index = node_to_optimized_index.get(node) + node_repr = _node_repr(node, i, optimized_index) + color = "red" if optimized_index != i else "black" + subgraph.node(node_repr, node_repr, color=color) + node_to_repr[node] = node_repr + + for actor, execution_nodes in actor_to_execution_schedule.items(): + for i, node in enumerate(execution_nodes): + node_repr = node_to_repr[node] + for out_edge, label in node.out_edges.items(): + out_task_idx, out_op_type = out_edge + out_node = graph[out_task_idx][out_op_type] + out_node_repr = node_to_repr[out_node] + color = "blue" if label == "nccl" else "black" + dot.edge(node_repr, out_node_repr, label=label, color=color) + + logger.info( + "Writing compiled graph schedule visualization " + "to compiled_graph_schedule.png" + ) + dot.render("compiled_graph_schedule", format="png", view=False) + + def _generate_actor_to_execution_schedule( graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] -) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]: +) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]: """ Generate an execution schedule for each actor. The schedule is a list of - operations to be executed. The function uses a topological sort algorithm - to generate the schedule. + operation nodes to be executed. The function uses a topological sort + algorithm to generate the schedule. Args: graph: A graph where each node is a _DAGOperationGraphNode. The key is @@ -407,13 +549,14 @@ def _generate_actor_to_execution_schedule( Returns: actor_to_execution_schedule: A dictionary that maps an actor handle to - the execution schedule which is a list of operations to be executed. + the execution schedule which is a list of operation nodes to be + executed. """ # Mapping from the actor handle to the execution schedule which is a list # of operations to be executed. actor_to_execution_schedule: Dict[ - "ray.actor.ActorHandle", List[_DAGNodeOperation] + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] ] = defaultdict(list) # A dictionary mapping an actor id to a list of candidate nodes. The list @@ -450,13 +593,13 @@ def _generate_actor_to_execution_schedule( nodes = [node for node in nodes if node not in visited_nodes] # Add the selected nodes to the execution schedule. for node in nodes: - actor_to_execution_schedule[node.actor_handle].append(node.operation) + actor_to_execution_schedule[node.actor_handle].append(node) visited_nodes.add(node) # Update the in-degree of the downstream nodes. 
for node in nodes: for out_node_task_idx, out_node_type in node.out_edges: out_node = graph[out_node_task_idx][out_node_type] - out_node.in_edges.remove((node.task_idx, node.operation.type)) + out_node.in_edges.pop((node.task_idx, node.operation.type)) if out_node.in_degree == 0 and out_node not in visited_nodes: # If the downstream node is already visited, it has been added # to the execution schedule. They are the NCCL read nodes in @@ -469,3 +612,77 @@ def _generate_actor_to_execution_schedule( assert len(candidates) == 0, "Expected all candidates to be empty" return actor_to_execution_schedule + + +def _generate_overlapped_execution_schedule( + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ], +) -> Dict["ray.actor.ActorHandle", List[_DAGOperationGraphNode]]: + """ + From an existing execution schedule, generate a new schedule by overlapping + computation and communication. + + Currently, the algorithm generates a new schedule for each actor as follows: + For each NCCL read operation (i.e., recv), scan backwards to find the nearest + compute node to swap with so that the NCCL read operation can be overlapped + with computation. + + Args: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the existing execution schedule for the actor. The schedule is a list + is a list of operations to be executed. + + Returns: + A dictionary that maps an actor handle to the overlapped execution schedule + for the actor. + """ + + actor_to_overlapped_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ] = copy.deepcopy(actor_to_execution_schedule) + for overlapped_schedule in actor_to_overlapped_schedule.values(): + for i in range(len(overlapped_schedule)): + if ( + overlapped_schedule[i].operation.type == _DAGNodeOperationType.READ + and overlapped_schedule[i].requires_nccl + ): + # For each NCCL read operation (i.e., recv), scan backwards + # to find the nearest compute node to swap with so that + # the NCCL read operation can be overlapped with computation. + for j in range(i - 1, -1, -1): + if ( + overlapped_schedule[j].operation.type + == _DAGNodeOperationType.COMPUTE + ): + # Found a desired compute operation, make the swap + nccl_read_op = overlapped_schedule[i] + prev_ops = overlapped_schedule[j:i] + overlapped_schedule[j + 1 : i + 1] = prev_ops + overlapped_schedule[j] = nccl_read_op + break + if ( + overlapped_schedule[j].operation.type + == _DAGNodeOperationType.READ + or overlapped_schedule[j].operation.type + == _DAGNodeOperationType.WRITE + ) and overlapped_schedule[j].requires_nccl: + # Found a NCCL read/write operation, skip the overlap + # optimization to keep relative order of NCCL operations + break + return actor_to_overlapped_schedule + + +def _extract_execution_schedule( + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGOperationGraphNode] + ] +) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]: + """ + Extract _DAGNodeOperation from _DAGOperationGraphNode in the schedule + and discard unnecessary information. 
+ """ + return { + actor: [node.operation for node in nodes] + for actor, nodes in actor_to_execution_schedule.items() + } diff --git a/python/ray/dag/dag_operation_future.py b/python/ray/dag/dag_operation_future.py new file mode 100644 index 000000000000..8ab165cb09fd --- /dev/null +++ b/python/ray/dag/dag_operation_future.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from ray.util.annotations import DeveloperAPI + + +if TYPE_CHECKING: + import torch + import cupy as cp + +T = TypeVar("T") + + +@DeveloperAPI +class DAGOperationFuture(ABC, Generic[T]): + """ + A future representing the result of a DAG operation. + + This is an abstraction that is internal to each actor, + and is not exposed to the DAG caller. + """ + + @abstractmethod + def wait(self): + """ + Wait for the future and return the result of the operation. + """ + raise NotImplementedError + + +@DeveloperAPI +class ResolvedFuture(DAGOperationFuture): + """ + A future that is already resolved. Calling `wait()` on this will + immediately return the result without blocking. + """ + + def __init__(self, result): + """ + Initialize a resolved future. + + Args: + result: The result of the future. + """ + self._result = result + + def wait(self): + """ + Wait and immediately return the result. This operation will not block. + """ + return self._result + + +@DeveloperAPI +class GPUFuture(DAGOperationFuture["torch.Tensor"]): + """ + A future that represents a GPU operation. + """ + + def __init__(self, buf: "torch.Tensor", stream: Optional["cp.cuda.Stream"] = None): + """ + Initialize a GPU future. + + Args: + buf: The buffer to return when the future is resolved. + stream: The CUDA stream to record the event on. If None, the current + stream is used. + """ + import cupy as cp + + if stream is None: + stream = cp.cuda.get_current_stream() + + self._buf = buf + self._event = cp.cuda.Event() + self._event.record(stream) + + def wait(self) -> "torch.Tensor": + """ + Wait for the future and return the result from the GPU operation. 
+ """ + import cupy as cp + + if self._event is not None: + current_stream = cp.cuda.get_current_stream() + current_stream.wait_event(self._event) + return self._buf diff --git a/python/ray/dag/tests/experimental/test_execution_schedule.py b/python/ray/dag/tests/experimental/test_execution_schedule.py index 34c8777f0be4..ffb14e78e60c 100644 --- a/python/ray/dag/tests/experimental/test_execution_schedule.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule.py @@ -10,6 +10,7 @@ _DAGNodeOperationType, _DAGOperationGraphNode, _DAGNodeOperation, + _extract_execution_schedule, _select_next_nodes, _build_dag_node_operation_graph, _add_edge, @@ -73,6 +74,9 @@ def set_ready_collective_idxs( _DAGNodeOperationType.COMPUTE ].ready_collective_idxs = ready_collective_idxs +def _generate_and_extract_execution_schedule(graph): + return _extract_execution_schedule(_generate_actor_to_execution_schedule(graph)) + class TestSelectNextNodes: """ @@ -718,7 +722,7 @@ def test_single_actor_1(self, monkeypatch): self.add_data_dependeny(graph[task_idx_1], graph[task_idx_2]) self.add_control_dependency(graph[task_idx_1], graph[task_idx_2]) - actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + actor_to_execution_schedule = _generate_and_extract_execution_schedule(graph) assert len(actor_to_execution_schedule) == 1 assert len(actor_to_execution_schedule[fake_actor]) == 6 assert actor_to_execution_schedule[fake_actor] == [ @@ -767,7 +771,7 @@ def test_single_actor_2(self, monkeypatch): self.add_control_dependency(graph[task_idx_1], graph[task_idx_2]) self.add_control_dependency(graph[task_idx_2], graph[task_idx_3]) - actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + actor_to_execution_schedule = _generate_and_extract_execution_schedule(graph) assert len(actor_to_execution_schedule) == 1 assert len(actor_to_execution_schedule[fake_actor]) == 9 assert actor_to_execution_schedule[fake_actor] == [ @@ -826,7 +830,7 @@ def test_two_actors_no_nccl(self, monkeypatch): self.add_control_dependency(graph[task_idx_1_1], graph[task_idx_1_2]) self.add_control_dependency(graph[task_idx_2_1], graph[task_idx_2_2]) - actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + actor_to_execution_schedule = _generate_and_extract_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 assert len(actor_to_execution_schedule[fake_actor_1]) == 6 assert len(actor_to_execution_schedule[fake_actor_2]) == 6 @@ -891,7 +895,7 @@ def test_two_actors_with_nccl(self, monkeypatch): self.add_control_dependency(graph[task_idx_1_1], graph[task_idx_1_2]) self.add_control_dependency(graph[task_idx_2_1], graph[task_idx_2_2]) - actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + actor_to_execution_schedule = _generate_and_extract_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 assert len(actor_to_execution_schedule[fake_actor_1]) == 6 assert len(actor_to_execution_schedule[fake_actor_2]) == 6 @@ -984,7 +988,7 @@ def test_simulate_pp_2workers_2batches_1f1b_with_nccl(self, monkeypatch): self.add_control_dependency(graph[task_idx_2_2], graph[task_idx_2_3]) self.add_control_dependency(graph[task_idx_2_3], graph[task_idx_2_4]) - actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + actor_to_execution_schedule = _generate_and_extract_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 assert len(actor_to_execution_schedule[worker_1]) == 12 assert 
len(actor_to_execution_schedule[worker_2]) == 12 @@ -1091,7 +1095,7 @@ def test_simulate_pp_2workers_2batches_1f1b_no_nccl(self, monkeypatch): self.add_control_dependency(graph[task_idx_2_2], graph[task_idx_2_3]) self.add_control_dependency(graph[task_idx_2_3], graph[task_idx_2_4]) - actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + actor_to_execution_schedule = _generate_and_extract_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 assert len(actor_to_execution_schedule[worker_1]) == 12 assert len(actor_to_execution_schedule[worker_2]) == 12 diff --git a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py index bdb9dafbdd62..12eeb7e4f208 100644 --- a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py @@ -42,6 +42,16 @@ def pop_trace(self): def read_input(self, input): return input + def send(self, shape, dtype, value: int, send_tensor=True): + if not send_tensor: + return 1 + return torch.ones(shape, dtype=dtype, device=self.device) * value + + def recv(self, tensor): + # Check that tensor got loaded to the correct device. + assert tensor.device == self.device + return (tensor[0].item(), tensor.shape, tensor.dtype) + def no_op(self, value): return value @@ -365,6 +375,71 @@ def test_three_actors_with_nccl_2(ray_start_regular, single_fetch, monkeypatch): assert torch.equal(tensor, tensor_cpu) +@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 3}], indirect=True) +@pytest.mark.parametrize("overlap_gpu_communication", [True, False]) +def test_overlap_gpu_communication(ray_start_regular, overlap_gpu_communication): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + sender1 = Worker.remote() + sender2 = Worker.remote() + receiver = Worker.remote() + + shape = (10000,) + dtype = torch.float16 + + with InputNode() as inp: + branch1 = sender1.send.bind(shape, dtype, inp) + + branch1 = branch1.with_type_hint( + TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + ) + branch1 = receiver.recv.bind(branch1) + + branch2 = sender2.send.bind(shape, dtype, inp) + branch2 = branch2.with_type_hint( + TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + ) + branch2 = receiver.recv.bind(branch2) + dag = MultiOutputNode([branch1, branch2]) + + # Test normal execution. 
+ compiled_dag = dag.experimental_compile( + _overlap_gpu_communication=overlap_gpu_communication + ) + + # Check receiver schedule + expected_no_overlap_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.READ), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + ] + expected_overlap_schedule = [ + (0, _DAGNodeOperationType.READ), + (1, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + ] + if overlap_gpu_communication: + expected_receiver_schedule = expected_overlap_schedule + else: + expected_receiver_schedule = expected_no_overlap_schedule + + receiver_schedule = compiled_dag.actor_to_execution_schedule[receiver] + + assert len(receiver_schedule) == len(expected_receiver_schedule) + for i, operation in enumerate(receiver_schedule): + assert operation.exec_task_idx == expected_receiver_schedule[i][0] + assert operation.type == expected_receiver_schedule[i][1] + + compiled_dag.teardown() + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 5d4f06e528e7..dcf4e15e6278 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -13,6 +13,8 @@ from ray.air._internal import torch_utils from ray.dag import InputNode, MultiOutputNode from ray.exceptions import RayChannelError +from python.ray.dag.compiled_dag_node import GPUFuture +from ray.dag.output_node import MultiOutputNode from ray.experimental.channel.gpu_communicator import ( GPUCommunicator, TorchTensorAllocator, @@ -63,6 +65,20 @@ def recv(self, tensor): assert tensor.device == self.device return (tensor[0].item(), tensor.shape, tensor.dtype) + def recv_and_matmul(self, two_d_tensor): + """ + Receive the tensor and do some expensive computation (matmul). + + Args: + two_d_tensor: a 2D tensor that has the same size for its dimensions + """ + # Check that tensor got loaded to the correct device. 
+ assert two_d_tensor.dim() == 2 + assert two_d_tensor.size(0) == two_d_tensor.size(1) + assert two_d_tensor.device == self.device + torch.matmul(two_d_tensor, two_d_tensor) + return (two_d_tensor[0][0].item(), two_d_tensor.shape, two_d_tensor.dtype) + def recv_dict(self, tensor_dict): vals = {} for i, tensor in tensor_dict.items(): @@ -212,6 +228,104 @@ def test_torch_tensor_nccl(ray_start_regular): @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_overlap(ray_start_regular, monkeypatch): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 3 + ), "This test requires at least 3 GPUs" + + original_gpu_future_init = GPUFuture.__init__ + def mock_gpu_future_init(self, *args, **kwargs): + init_ts = time.monotonic() + original_gpu_future_init(self, *args, **kwargs) + + monkeypatch.setattr(GPUFuture, "__init__", mock_gpu_future_init) + + worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + sender1 = worker_cls.remote() + sender2 = worker_cls.remote() + receiver = worker_cls.remote() + + shape = (10000, ) + dtype = torch.float16 + + with InputNode() as inp: + branches = [sender.send.bind(shape, dtype, inp) for sender in senders] + branches = [ + branch.with_type_hint( + TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + ) + for branch in branches + ] + branches = [receiver.recv_and_matmul.bind(branch) for branch in branches] + dag = MultiOutputNode(branches) + + # Test normal execution. + compiled_dag = dag.experimental_compile( + _overlap_gpu_communication=overlap_gpu_communication + ) + + start = time.monotonic() + for i in range(5): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == [(i, shape, dtype)] * num_senders + duration = time.monotonic() - start + print(f"{overlap_gpu_communication=}, {duration=}") + + compiled_dag.teardown() + + +@pytest.mark.parametrize( + "ray_start_regular, overlap_gpu_communication", + [({"num_cpus": 4}, False), ({"num_cpus": 4}, True)], + indirect=["ray_start_regular"], +) +def test_torch_tensor_nccl_overlap_timed(ray_start_regular, overlap_gpu_communication): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 4 + ), "This test requires at least 4 GPUs" + + worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + num_senders = 3 + senders = [worker_cls.remote() for _ in range(num_senders)] + receiver = worker_cls.remote() + + shape = (10000, 10000) + dtype = torch.float16 + + with InputNode() as inp: + branches = [sender.send.bind(shape, dtype, inp) for sender in senders] + branches = [ + branch.with_type_hint( + TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + ) + for branch in branches + ] + branches = [receiver.recv_and_matmul.bind(branch) for branch in branches] + dag = MultiOutputNode(branches) + + # Test normal execution. + compiled_dag = dag.experimental_compile( + _overlap_gpu_communication=overlap_gpu_communication + ) + + start = time.monotonic() + for i in range(5): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == [(i, shape, dtype)] * num_senders + duration = time.monotonic() - start + print(f"{overlap_gpu_communication=}, {duration=}") + + compiled_dag.teardown() + + def test_torch_tensor_nccl_disallows_driver(ray_start_regular): """ Check that the driver cannot participate in the NCCL group, i.e. 
DAG input @@ -272,6 +386,74 @@ def test_torch_tensor_custom_comm(ray_start_regular): sender = actor_cls.remote() receiver = actor_cls.remote() + class TestNcclGroup(GPUCommunicator): + """ + A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. + """ + + import cupy as cp + + def __init__(self, world_size, comm_id, actor_handles): + self._world_size = world_size + self._comm_id = comm_id + self._actor_handles = actor_handles + self._inner = None + + def initialize(self, rank: int) -> None: + self._inner = _NcclGroup( + self._world_size, + self._comm_id, + rank, + self._actor_handles, + torch.cuda.current_stream().cuda_stream, + ) + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_world_size(self) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. + return self._world_size + + def get_self_rank(self) -> Optional[int]: + if self._inner is None: + return None + return self._inner.get_self_rank() + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + return self._inner.send(value, peer_rank) + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) + + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._inner.recv_stream + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._inner.send_stream + + def destroy(self) -> None: + return self._inner.destroy() + from cupy.cuda import nccl class TestNcclGroup(GPUCommunicator): @@ -382,6 +564,8 @@ class MockNcclGroup(GPUCommunicator): A mock NCCL group for testing. Send and recv are not implemented. """ + import cupy as cp + def __init__(self, world_size, actor_handles): self._world_size = world_size self._actor_handles = actor_handles @@ -432,6 +616,14 @@ def allreduce( ) -> None: raise NotImplementedError + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return None + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return None + def destroy(self) -> None: pass @@ -521,6 +713,8 @@ class InitedNcclGroup(GPUCommunicator): A custom NCCL group based on existing torch.distributed setup. 
""" + import cupy as cp + def __init__(self, world_size, actor_handles): self._world_size = world_size self._actor_handles = actor_handles @@ -573,6 +767,18 @@ def allreduce( ) -> None: raise NotImplementedError + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + import cupy as cp + + return cp.cuda.get_current_stream() + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + import cupy as cp + + return cp.cuda.get_current_stream() + def destroy(self) -> None: pass @@ -708,6 +914,57 @@ def test_torch_tensor_nccl_nested_dynamic(ray_start_regular): @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_direct_return_error(ray_start_regular): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + sender = actor_cls.remote() + receiver = actor_cls.remote() + + shape = (10,) + dtype = torch.float16 + + # Passing a non-tensor value when _direct_return=True and tranport="nccl" + # fails. + with InputNode() as inp: + dag = sender.send.bind(inp.shape, inp.dtype, inp.value, inp.send_tensor) + dag = dag.with_type_hint( + TorchTensorType( + transport=TorchTensorType.NCCL, + _direct_return=True, + ) + ) + dag = receiver.recv.bind(dag) + + compiled_dag = dag.experimental_compile() + + ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=True) + assert ray.get(ref) == (1, shape, dtype) + + ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=False) + with pytest.raises(RayChannelError): + ray.get(ref) + + # For direct_return=True tensors, the DAG will be torn down after any task + # throws an application-level exception, such as when the task returns + # something other than a torch.Tensor. Check that we can no longer submit + # to the DAG. + with pytest.raises(RayChannelError): + ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=True) + + compiled_dag.teardown() + + # TODO(swang): This currently requires time.sleep to avoid some issue with + # following tests. + time.sleep(3) + + @pytest.mark.parametrize("static_shape", [False, True]) @pytest.mark.parametrize("direct_return", [False, True]) def test_torch_tensor_exceptions(ray_start_regular, static_shape, direct_return): @@ -739,7 +996,9 @@ def test_torch_tensor_exceptions(ray_start_regular, static_shape, direct_return) ) dag = receiver.recv.bind(dag) - compiled_dag = dag.experimental_compile() + compiled_dag = dag.experimental_compile( + _overlap_gpu_communication=overlap_gpu_communication + ) shape = (10,) dtype = torch.float16 diff --git a/python/ray/experimental/channel/common.py b/python/ray/experimental/channel/common.py index 8b4d4920e377..55715eefc856 100644 --- a/python/ray/experimental/channel/common.py +++ b/python/ray/experimental/channel/common.py @@ -121,6 +121,7 @@ def set_nccl_group_id(self, group_id: str) -> None: class ChannelContext: serialization_context = _SerializationContext() _torch_device: Optional["torch.device"] = None + _current_stream: Optional["torch.cuda.Stream"] = None def __init__(self): # Used for the torch.Tensor NCCL transport. 
diff --git a/python/ray/experimental/channel/gpu_communicator.py b/python/ray/experimental/channel/gpu_communicator.py index 26cae2ff9409..acb64c9e5da1 100644 --- a/python/ray/experimental/channel/gpu_communicator.py +++ b/python/ray/experimental/channel/gpu_communicator.py @@ -6,6 +6,7 @@ from ray.util.annotations import DeveloperAPI if TYPE_CHECKING: + import cupy as cp import torch @@ -79,6 +80,7 @@ def send(self, value: "torch.Tensor", peer_rank: int) -> None: value: The torch.Tensor to send. It should already be on this actor's default device. peer_rank: The rank of the actor to send to. + future: An optional future to wait on before sending. """ raise NotImplementedError @@ -105,6 +107,22 @@ def recv( """ raise NotImplementedError + @property + @abstractmethod + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + """ + Return the cuda stream used for receiving tensors. + """ + raise NotImplementedError + + @property + @abstractmethod + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + """ + Return the cuda stream used for sending tensors. + """ + raise NotImplementedError + @abstractmethod def allreduce( self, diff --git a/python/ray/experimental/channel/nccl_group.py b/python/ray/experimental/channel/nccl_group.py index 8f4804848323..00d8dcce8eb5 100644 --- a/python/ray/experimental/channel/nccl_group.py +++ b/python/ray/experimental/channel/nccl_group.py @@ -36,6 +36,7 @@ def __init__( rank: Optional[int], actor_handles: List["ray.actor.ActorHandle"], cuda_stream: Optional[int], + use_communication_streams: bool = False, ): """ Initialize a NCCL communicator that can be used to communicate p2p with @@ -67,11 +68,15 @@ def __init__( actor_handles: A list of actor handles, in rank order. cuda_stream: A raw CUDA stream to dispatch NCCL ops to. If rank is specified, then this must be specified too. + use_communication_streams: Whether to use dedicated send and recv + streams for communication. If True, communication and computation + can be overlapped to improve perfomrance. """ self._world_size = world_size self._rank: Optional[int] = rank self.nccl_util: Optional[ModuleType] = None self._actor_handles = actor_handles + self._use_communication_streams = use_communication_streams if rank is not None: assert ray.get_gpu_ids(), "NCCL actor has no GPUs assigned" @@ -91,6 +96,8 @@ def __init__( self._comm = None self._cuda_stream: Optional["cp.cuda.ExternalStream"] = None + self._send_stream: Optional["cp.cuda.ExternalStream"] = None + self._recv_stream: Optional["cp.cuda.ExternalStream"] = None if cuda_stream is not None: assert rank is not None, "NCCL actor has no rank assigned" @@ -104,6 +111,19 @@ def __init__( cuda_stream, device_id=device.index ) + if use_communication_streams: + import torch + + self._send_stream = cp.cuda.ExternalStream( + torch.cuda.Stream().cuda_stream, device_id=device.index + ) + self._recv_stream = cp.cuda.ExternalStream( + torch.cuda.Stream().cuda_stream, device_id=device.index + ) + else: + self._send_stream = self._cuda_stream + self._recv_stream = self._cuda_stream + self._closed = False def initialize(self, rank: int) -> None: @@ -157,6 +177,15 @@ def send(self, buf: "torch.Tensor", peer_rank: int) -> None: """ if self._closed: raise RayChannelError("NCCL group has been destroyed.") + + if self._use_communication_streams: + # We observed that if all recv/compute/send operations run on GPU, + # since there is no synchronization, the CPU execution loop may be + # far ahead of the GPU operations and lead to runtime failures. 
+ # To avoid that, we synchronize on the send stream. + # TODO(rui): find a better approach + self._send_stream.synchronize() + # TODO(swang): Handle send/recv async NCCL errors such as network # failures. self._comm.send( @@ -164,7 +193,7 @@ def send(self, buf: "torch.Tensor", peer_rank: int) -> None: buf.numel(), self.nccl_util.get_nccl_tensor_dtype(buf), peer_rank, - self._cuda_stream.ptr, + self._send_stream.ptr, ) def recv( @@ -189,19 +218,37 @@ def recv( raise RayChannelError("NCCL group has been destroyed.") assert allocator is not None, "NCCL group requires a tensor allocator" buf = allocator(shape, dtype) - self._comm.recv( - self.nccl_util.get_tensor_ptr(buf), - buf.numel(), - self.nccl_util.get_nccl_tensor_dtype(buf), - peer_rank, - self._cuda_stream.ptr, - ) - # Buffer values are undefined if NCCL ops are aborted. Therefore, we - # need to synchronize here and check that the channel is still open to - # ensure that the receive buffer is valid. - # TODO(swang): Avoid CUDA synchronization. - self._cuda_stream.synchronize() + if self._use_communication_streams: + # We observed that if all recv/compute/send operations run on GPU, + # since there is no synchronization, the CPU execution loop may be + # far ahead of the GPU operations and lead to runtime failures. + # To avoid that, we synchronize on the recv stream. + # TODO(rui): find a better approach + self._recv_stream.synchronize() + + self._comm.recv( + self.nccl_util.get_tensor_ptr(buf), + buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(buf), + peer_rank, + self._recv_stream.ptr, + ) + else: + self._comm.recv( + self.nccl_util.get_tensor_ptr(buf), + buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(buf), + peer_rank, + self._recv_stream.ptr, + ) + + # Buffer values are undefined if NCCL ops are aborted. Therefore, we + # need to synchronize here and check that the channel is still open to + # ensure that the receive buffer is valid. + # TODO(swang): Avoid CUDA synchronization. + self._cuda_stream.synchronize() + if self._closed: raise RayChannelError("NCCL group has been destroyed.") return buf @@ -233,6 +280,14 @@ def allreduce( if self._closed: raise RayChannelError("NCCL group has been destroyed.") + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._recv_stream + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._send_stream + def destroy(self) -> None: """ Destroy the NCCL group. diff --git a/python/ray/experimental/channel/torch_tensor_nccl_channel.py b/python/ray/experimental/channel/torch_tensor_nccl_channel.py index 5abec1ab409b..4b8baedeebd0 100644 --- a/python/ray/experimental/channel/torch_tensor_nccl_channel.py +++ b/python/ray/experimental/channel/torch_tensor_nccl_channel.py @@ -548,6 +548,7 @@ def _do_init_nccl_group( comm_id, rank, actor_handles, + use_communication_streams, custom_nccl_group: Optional[GPUCommunicator] = None, ): import torch @@ -567,6 +568,7 @@ def _do_init_nccl_group( rank, actor_handles, torch.cuda.current_stream().cuda_stream, + use_communication_streams, ) @@ -627,6 +629,7 @@ def _get_ranks( def _init_nccl_group( actors: List[ray.actor.ActorHandle], custom_nccl_group: Optional[GPUCommunicator] = None, + use_communication_streams: bool = False, ) -> str: """ Initialize a NCCL group with the given actors. If a custom NCCL group is @@ -635,6 +638,9 @@ def _init_nccl_group( Args: actors: A list of actors that participate in the NCCL group. custom_nccl_group: A custom NCCL group to initialize. 
+ use_communication_streams: Whether to use dedicated send and recv + streams for communication. If True, communication and computation + can be overlapped to improve perfomrance. """ ctx = ChannelContext.get_current() @@ -675,6 +681,7 @@ def _init_nccl_group( nccl_comm_id, rank, actors, + use_communication_streams, custom_nccl_group, ) for rank, actor in zip(ranks, actors) From f4258d98421bf6d1bced1f3caa04fa69d451d9a6 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Sun, 27 Oct 2024 12:01:58 -0700 Subject: [PATCH 02/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 46 +++++----- python/ray/dag/dag_node_operation.py | 1 - .../experimental/test_execution_schedule.py | 1 + .../experimental/test_mocked_nccl_dag.py | 55 +++++++++++- .../experimental/test_torch_tensor_dag.py | 89 ++++++++++--------- 5 files changed, 127 insertions(+), 65 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 4b98633103f8..4dcd9edf007a 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -37,8 +37,8 @@ from ray.experimental.channel.shared_memory_channel import ( SharedMemoryType, - TorchTensorType, ) +from ray.experimental.channel.torch_tensor_type import TorchTensorType from ray.experimental.channel.torch_tensor_nccl_channel import ( _init_nccl_group, @@ -451,7 +451,9 @@ def prepare(self): for type_hint in self.input_type_hints: if type_hint.requires_nccl(): nccl_group_id = _get_nccl_group_id(type_hint) - nccl_group = ChannelContext.get_current().nccl_groups.get(nccl_group_id) + nccl_group = ChannelContext.get_current().nccl_groups.get( + nccl_group_id + ) assert nccl_group is not None if not isinstance(self._recv_stream, nullcontext): assert self._recv_stream == nccl_group.recv_stream, ( @@ -1096,7 +1098,9 @@ def _preprocess(self) -> None: "Expected P2P actor handles to be a subset of the custom NCCL group" ) self._nccl_group_id_p2p = _init_nccl_group( - nccl_actors_p2p, self._custom_nccl_group_p2p, self._overlap_gpu_communication + nccl_actors_p2p, + self._custom_nccl_group_p2p, + self._overlap_gpu_communication, ) custom_nccl_group_to_id[ self._custom_nccl_group_p2p @@ -1126,7 +1130,9 @@ def _preprocess(self) -> None: self._nccl_group_id_p2p = actors_to_nccl_group_id[actors] else: self._nccl_group_id_p2p = _init_nccl_group( - nccl_actors_p2p, self._custom_nccl_group_p2p, self._overlap_gpu_communication + nccl_actors_p2p, + self._custom_nccl_group_p2p, + self._overlap_gpu_communication, ) actors_to_nccl_group_id[actors] = self._nccl_group_id_p2p @@ -1624,8 +1630,8 @@ def _generate_dag_operation_graph_node( for exec_task_idx, exec_task in enumerate(executable_tasks): # Divide a DAG node into three _DAGOperationGraphNodes: READ, COMPUTE, # and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation. 
- task_index = exec_task.task_idx - dag_node = self.idx_to_task[task_index].dag_node + task_idx = exec_task.task_idx + dag_node = self.idx_to_task[task_idx].dag_node method_name = exec_task.method_name actor_handle = dag_node._get_actor_handle() requires_nccl = dag_node.type_hint.requires_nccl() @@ -1639,7 +1645,7 @@ def _generate_dag_operation_graph_node( _DAGNodeOperation( exec_task_idx, _DAGNodeOperationType.READ, method_name ), - task_index, + task_idx, actor_handle, upstream_requires_nccl, ) @@ -1647,7 +1653,7 @@ def _generate_dag_operation_graph_node( _DAGNodeOperation( exec_task_idx, _DAGNodeOperationType.COMPUTE, method_name ), - task_index, + task_idx, actor_handle, False, ) @@ -1655,7 +1661,7 @@ def _generate_dag_operation_graph_node( _DAGNodeOperation( exec_task_idx, _DAGNodeOperationType.WRITE, method_name ), - task_index, + task_idx, actor_handle, requires_nccl, ) @@ -1714,22 +1720,22 @@ def _build_execution_schedule( # Step 2: Generate an execution schedule for each actor using topological sort actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) - # Step 3: Optimize the execution schedule if configured - if self._optimize_execution_schedule: + # Step 3: Overlap GPU communication for the execution schedule if configured + if self._overlap_gpu_communication: actor_to_overlapped_schedule = _generate_overlapped_execution_schedule( actor_to_execution_schedule ) - else: - actor_to_overlapped_schedule = None - from ray.dag.constants import RAY_ADAG_VISUALIZE_SCHEDULE + from ray.dag.constants import RAY_ADAG_VISUALIZE_SCHEDULE - if RAY_ADAG_VISUALIZE_SCHEDULE: - _visualize_execution_schedule( - actor_to_execution_schedule, actor_to_overlapped_schedule, graph - ) - - return _extract_execution_schedule(actor_to_overlapped_schedule) + if RAY_ADAG_VISUALIZE_SCHEDULE: + _visualize_execution_schedule( + actor_to_execution_schedule, actor_to_overlapped_schedule, graph + ) + return _extract_execution_schedule(actor_to_overlapped_schedule) + else: + actor_to_overlapped_schedule = None + return _extract_execution_schedule(actor_to_execution_schedule) def _detect_deadlock(self) -> bool: """ diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index d9cb3fe7c2e6..789f70acd297 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -1,7 +1,6 @@ from functools import total_ordering from enum import Enum from typing import Set, Tuple, List, Dict, Optional -from typing import Optional, Tuple, List, Dict import copy import logging import ray diff --git a/python/ray/dag/tests/experimental/test_execution_schedule.py b/python/ray/dag/tests/experimental/test_execution_schedule.py index ffb14e78e60c..b66564c6392a 100644 --- a/python/ray/dag/tests/experimental/test_execution_schedule.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule.py @@ -74,6 +74,7 @@ def set_ready_collective_idxs( _DAGNodeOperationType.COMPUTE ].ready_collective_idxs = ready_collective_idxs + def _generate_and_extract_execution_schedule(graph): return _extract_execution_schedule(_generate_actor_to_execution_schedule(graph)) diff --git a/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py b/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py index 634d411272f2..412848c9ee03 100644 --- a/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py +++ b/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py @@ -15,7 +15,7 @@ ) from ray.tests.conftest import * # noqa from ray.tests.conftest import 
wait_for_condition -from ray.dag import InputNode +from ray.dag import InputNode, MultiOutputNode def error_logged(capsys, msg): @@ -416,6 +416,59 @@ def test_p2p_static_shape_and_direct_return( wait_for_condition(lambda: error_logged(capsys, msg)) +@pytest.mark.parametrize( + "ray_start_cluster", + [ + { + "num_cpus": 3, + "num_gpus": 3, + "num_nodes": 1, + } + ], + indirect=True, +) +@pytest.mark.parametrize("overlap_gpu_communication", [False]) +def test_overlap_gpu_communication(ray_start_cluster, overlap_gpu_communication): + sender1 = MockedWorker.remote() + sender2 = MockedWorker.remote() + receiver = MockedWorker.remote() + + ray.get( + [ + sender1.start_mock.remote(), + sender2.start_mock.remote(), + receiver.start_mock.remote(), + ] + ) + + shape = (10,) + dtype = torch.float16 + + with InputNode() as inp: + branch1 = sender1.send.bind(shape, dtype, inp) + branch1 = branch1.with_type_hint( + TorchTensorType(transport="nccl", _static_shape=True, _direct_return=True) + ) + branch2 = sender2.send.bind(shape, dtype, inp) + branch2 = branch2.with_type_hint( + TorchTensorType(transport="nccl", _static_shape=True, _direct_return=True) + ) + branch1 = receiver.recv.bind(branch1) + branch2 = receiver.recv.bind(branch2) + dag = MultiOutputNode([branch1, branch2]) + + compiled_dag = dag.experimental_compile( + _overlap_gpu_communication=overlap_gpu_communication + ) + + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == [(i, shape, dtype)] * 2 + + compiled_dag.teardown() + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index dcf4e15e6278..a6beed00cf20 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -10,10 +10,10 @@ import ray.cluster_utils import ray.experimental.collective as collective import torch +import time from ray.air._internal import torch_utils -from ray.dag import InputNode, MultiOutputNode +from ray.dag import InputNode from ray.exceptions import RayChannelError -from python.ray.dag.compiled_dag_node import GPUFuture from ray.dag.output_node import MultiOutputNode from ray.experimental.channel.gpu_communicator import ( GPUCommunicator, @@ -227,55 +227,55 @@ def test_torch_tensor_nccl(ray_start_regular): assert ray.get(ref) == (i, shape, dtype) -@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) -def test_torch_tensor_nccl_overlap(ray_start_regular, monkeypatch): - if not USE_GPU: - pytest.skip("NCCL tests require GPUs") +# @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +# def test_torch_tensor_nccl_overlap(ray_start_regular, monkeypatch): +# if not USE_GPU: +# pytest.skip("NCCL tests require GPUs") - assert ( - sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 3 - ), "This test requires at least 3 GPUs" +# assert ( +# sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 3 +# ), "This test requires at least 3 GPUs" - original_gpu_future_init = GPUFuture.__init__ - def mock_gpu_future_init(self, *args, **kwargs): - init_ts = time.monotonic() - original_gpu_future_init(self, *args, **kwargs) +# original_gpu_future_init = GPUFuture.__init__ +# def mock_gpu_future_init(self, *args, **kwargs): +# init_ts = time.monotonic() +# original_gpu_future_init(self, 
*args, **kwargs) - monkeypatch.setattr(GPUFuture, "__init__", mock_gpu_future_init) +# monkeypatch.setattr(GPUFuture, "__init__", mock_gpu_future_init) - worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) - sender1 = worker_cls.remote() - sender2 = worker_cls.remote() - receiver = worker_cls.remote() +# worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) +# sender1 = worker_cls.remote() +# sender2 = worker_cls.remote() +# receiver = worker_cls.remote() - shape = (10000, ) - dtype = torch.float16 +# shape = (10000, ) +# dtype = torch.float16 - with InputNode() as inp: - branches = [sender.send.bind(shape, dtype, inp) for sender in senders] - branches = [ - branch.with_type_hint( - TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) - ) - for branch in branches - ] - branches = [receiver.recv_and_matmul.bind(branch) for branch in branches] - dag = MultiOutputNode(branches) +# with InputNode() as inp: +# branches = [sender.send.bind(shape, dtype, inp) for sender in senders] +# branches = [ +# branch.with_type_hint( +# TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) +# ) +# for branch in branches +# ] +# branches = [receiver.recv_and_matmul.bind(branch) for branch in branches] +# dag = MultiOutputNode(branches) - # Test normal execution. - compiled_dag = dag.experimental_compile( - _overlap_gpu_communication=overlap_gpu_communication - ) +# # Test normal execution. +# compiled_dag = dag.experimental_compile( +# _overlap_gpu_communication=overlap_gpu_communication +# ) - start = time.monotonic() - for i in range(5): - ref = compiled_dag.execute(i) - result = ray.get(ref) - assert result == [(i, shape, dtype)] * num_senders - duration = time.monotonic() - start - print(f"{overlap_gpu_communication=}, {duration=}") +# start = time.monotonic() +# for i in range(5): +# ref = compiled_dag.execute(i) +# result = ray.get(ref) +# assert result == [(i, shape, dtype)] * num_senders +# duration = time.monotonic() - start +# print(f"{overlap_gpu_communication=}, {duration=}") - compiled_dag.teardown() +# compiled_dag.teardown() @pytest.mark.parametrize( @@ -967,7 +967,10 @@ def test_torch_tensor_nccl_direct_return_error(ray_start_regular): @pytest.mark.parametrize("static_shape", [False, True]) @pytest.mark.parametrize("direct_return", [False, True]) -def test_torch_tensor_exceptions(ray_start_regular, static_shape, direct_return): +@pytest.mark.parametrize("overlap_gpu_communication", [False, True]) +def test_torch_tensor_exceptions( + ray_start_regular, static_shape, direct_return, overlap_gpu_communication +): """ Test exceptions being thrown by a NCCL sending task. 
""" From 2497dd5189f82b71e6e8e14c0b67f15b19e112de Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Sun, 27 Oct 2024 17:46:31 -0700 Subject: [PATCH 03/16] up Signed-off-by: Rui Qiao --- python/ray/dag/__init__.py | 3 +- .../experimental/test_mocked_nccl_dag.py | 12 +- .../experimental/test_torch_tensor_dag.py | 118 ------------------ python/ray/experimental/channel/conftest.py | 25 ++++ 4 files changed, 38 insertions(+), 120 deletions(-) diff --git a/python/ray/dag/__init__.py b/python/ray/dag/__init__.py index bc081be76e50..bc9970899cf4 100644 --- a/python/ray/dag/__init__.py +++ b/python/ray/dag/__init__.py @@ -11,7 +11,7 @@ DAGInputData, ) from ray.dag.output_node import MultiOutputNode -from ray.dag.dag_operation_future import DAGOperationFuture +from ray.dag.dag_operation_future import DAGOperationFuture, GPUFuture from ray.dag.constants import ( PARENT_CLASS_NODE_KEY, PREV_CLASS_METHOD_CALL_KEY, @@ -30,6 +30,7 @@ "DAGNode", "DAGOperationFuture", "FunctionNode", + "GPUFuture", "InputNode", "InputAttributeNode", "DAGInputData", diff --git a/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py b/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py index 412848c9ee03..9ffe00116cd7 100644 --- a/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py +++ b/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py @@ -427,8 +427,18 @@ def test_p2p_static_shape_and_direct_return( ], indirect=True, ) -@pytest.mark.parametrize("overlap_gpu_communication", [False]) +@pytest.mark.parametrize("overlap_gpu_communication", [True]) def test_overlap_gpu_communication(ray_start_cluster, overlap_gpu_communication): + # Barrier name should be barrier-{sender rank}-{receiver rank}. + # Create a barrier in both directions because we don't know which rank will + # get assigned to sender and receiver. 
+ barriers = [ # noqa + Barrier.options(name=f"barrier-{i}-{j}").remote() + for i in range(3) + for j in range(3) + if i != j + ] + sender1 = MockedWorker.remote() sender2 = MockedWorker.remote() receiver = MockedWorker.remote() diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index a6beed00cf20..f1bf961032db 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -227,57 +227,6 @@ def test_torch_tensor_nccl(ray_start_regular): assert ray.get(ref) == (i, shape, dtype) -# @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) -# def test_torch_tensor_nccl_overlap(ray_start_regular, monkeypatch): -# if not USE_GPU: -# pytest.skip("NCCL tests require GPUs") - -# assert ( -# sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 3 -# ), "This test requires at least 3 GPUs" - -# original_gpu_future_init = GPUFuture.__init__ -# def mock_gpu_future_init(self, *args, **kwargs): -# init_ts = time.monotonic() -# original_gpu_future_init(self, *args, **kwargs) - -# monkeypatch.setattr(GPUFuture, "__init__", mock_gpu_future_init) - -# worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) -# sender1 = worker_cls.remote() -# sender2 = worker_cls.remote() -# receiver = worker_cls.remote() - -# shape = (10000, ) -# dtype = torch.float16 - -# with InputNode() as inp: -# branches = [sender.send.bind(shape, dtype, inp) for sender in senders] -# branches = [ -# branch.with_type_hint( -# TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) -# ) -# for branch in branches -# ] -# branches = [receiver.recv_and_matmul.bind(branch) for branch in branches] -# dag = MultiOutputNode(branches) - -# # Test normal execution. -# compiled_dag = dag.experimental_compile( -# _overlap_gpu_communication=overlap_gpu_communication -# ) - -# start = time.monotonic() -# for i in range(5): -# ref = compiled_dag.execute(i) -# result = ray.get(ref) -# assert result == [(i, shape, dtype)] * num_senders -# duration = time.monotonic() - start -# print(f"{overlap_gpu_communication=}, {duration=}") - -# compiled_dag.teardown() - - @pytest.mark.parametrize( "ray_start_regular, overlap_gpu_communication", [({"num_cpus": 4}, False), ({"num_cpus": 4}, True)], @@ -456,73 +405,6 @@ def destroy(self) -> None: from cupy.cuda import nccl - class TestNcclGroup(GPUCommunicator): - """ - A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. - """ - - def __init__(self, world_size, comm_id, actor_handles): - self._world_size = world_size - self._comm_id = comm_id - self._actor_handles = actor_handles - self._inner = None - - def initialize(self, rank: int) -> None: - self._inner = _NcclGroup( - self._world_size, - self._comm_id, - rank, - self._actor_handles, - torch.cuda.current_stream().cuda_stream, - ) - - def get_rank(self, actor: ray.actor.ActorHandle) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. - actor_ids = [a._ray_actor_id for a in self._actor_handles] - try: - rank = actor_ids.index(actor._ray_actor_id) - except ValueError: - raise ValueError("Actor is not in the NCCL group.") - return rank - - def get_world_size(self) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. 
- return self._world_size - - def get_self_rank(self) -> Optional[int]: - if self._inner is None: - return None - return self._inner.get_self_rank() - - def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: - return self._actor_handles - - def send(self, value: "torch.Tensor", peer_rank: int) -> None: - return self._inner.send(value, peer_rank) - - def recv( - self, - shape: Tuple[int], - dtype: "torch.dtype", - peer_rank: int, - allocator: Optional[TorchTensorAllocator] = None, - ) -> "torch.Tensor": - return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) - - def allreduce( - self, - send_buf: "torch.Tensor", - recv_buf: "torch.Tensor", - op: ReduceOp = ReduceOp.SUM, - ) -> None: - self._inner.allreduce(send_buf, recv_buf, op) - recv_buf += 1 - - def destroy(self) -> None: - return self._inner.destroy() - comm_id = nccl.get_unique_id() nccl_group = TestNcclGroup(2, comm_id, [sender, receiver]) with InputNode() as inp: diff --git a/python/ray/experimental/channel/conftest.py b/python/ray/experimental/channel/conftest.py index 3a4175e607f6..156ae7b48d81 100644 --- a/python/ray/experimental/channel/conftest.py +++ b/python/ray/experimental/channel/conftest.py @@ -1,4 +1,5 @@ import asyncio +import time from collections import defaultdict from typing import Optional, Tuple from unittest import mock @@ -6,6 +7,7 @@ import torch import ray +import ray.dag import ray.experimental.channel as ray_channel from ray.experimental.channel.gpu_communicator import TorchTensorAllocator @@ -105,6 +107,21 @@ def destroy(self) -> None: ray.kill(barrier) +class MockGPUFuture: + def __init__(self, buf, stream=None): + from ray.dag.dag_operation_future import GPUFuture + + self._inner = GPUFuture(buf, stream) + self._init_ts = time.monotonic() + print(f"Created GPUFuture at {self._init_ts}") + + def wait(self): + result = self._inner.wait() + self._wait_ts = time.monotonic() + print(f"Waited for {self._wait_ts - self._init_ts} seconds") + return result + + def start_nccl_mock(): """ Patch methods that require CUDA. 
@@ -130,6 +147,10 @@ def start_nccl_mock(): "torch.cuda.current_stream", new_callable=lambda: MockCudaStream ) stream_patcher.start() + new_stream_patcher = mock.patch( + "torch.cuda.Stream", new_callable=lambda: MockCudaStream + ) + new_stream_patcher.start() tensor_patcher = mock.patch("torch.Tensor.device", torch.device("cuda")) tensor_patcher.start() tensor_patcher = mock.patch("torch.Tensor.is_cuda", True) @@ -140,6 +161,10 @@ def start_nccl_mock(): ) tensor_allocator_patcher.start() + # Mock GPUFuture + ray.dag.dag_operation_future.GPUFuture = MockGPUFuture + print(f"{ray.dag.dag_operation_future.GPUFuture=}") + ctx = ray_channel.ChannelContext.get_current() ctx.set_torch_device(torch.device("cuda")) From b0e3239145e3894e700be0856b4bfa622b97eb90 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Sun, 27 Oct 2024 17:49:23 -0700 Subject: [PATCH 04/16] up Signed-off-by: Rui Qiao --- .../ray/dag/tests/experimental/test_execution_schedule_gpu.py | 4 ++-- python/ray/dag/tests/experimental/test_torch_tensor_dag.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py index 12eeb7e4f208..b39d1091925e 100644 --- a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py @@ -392,13 +392,13 @@ def test_overlap_gpu_communication(ray_start_regular, overlap_gpu_communication) branch1 = sender1.send.bind(shape, dtype, inp) branch1 = branch1.with_type_hint( - TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + TorchTensorType(transport="nccl", _static_shape=True, _direct_return=True) ) branch1 = receiver.recv.bind(branch1) branch2 = sender2.send.bind(shape, dtype, inp) branch2 = branch2.with_type_hint( - TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + TorchTensorType(transport="nccl", _static_shape=True, _direct_return=True) ) branch2 = receiver.recv.bind(branch2) dag = MultiOutputNode([branch1, branch2]) diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index f1bf961032db..56cc43ca3f48 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -252,7 +252,9 @@ def test_torch_tensor_nccl_overlap_timed(ray_start_regular, overlap_gpu_communic branches = [sender.send.bind(shape, dtype, inp) for sender in senders] branches = [ branch.with_type_hint( - TorchTensorType(shape, dtype, transport="nccl", _direct_return=True) + TorchTensorType( + transport="nccl", _static_shape=True, _direct_return=True + ) ) for branch in branches ] From 4e10e5e575d29ad310391ae2cb063846ee5d9ec8 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Mon, 28 Oct 2024 15:48:29 +0000 Subject: [PATCH 05/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 2 +- .../experimental/test_torch_tensor_dag.py | 70 +++++-------------- 2 files changed, 20 insertions(+), 52 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 4dcd9edf007a..31b7a914bdd1 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -1655,7 +1655,7 @@ def _generate_dag_operation_graph_node( ), task_idx, actor_handle, - False, + isinstance(dag_node, CollectiveOutputNode), ) write_node = _DAGOperationGraphNode( _DAGNodeOperation( diff --git 
a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 56cc43ca3f48..a2ad34165722 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -394,6 +394,15 @@ def recv( ) -> "torch.Tensor": return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ) -> None: + self._inner.allreduce(send_buf, recv_buf, op) + recv_buf += 1 + @property def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: return self._inner.recv_stream @@ -798,57 +807,6 @@ def test_torch_tensor_nccl_nested_dynamic(ray_start_regular): @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) -def test_torch_tensor_nccl_direct_return_error(ray_start_regular): - if not USE_GPU: - pytest.skip("NCCL tests require GPUs") - - assert ( - sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 - ), "This test requires at least 2 GPUs" - - actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) - - sender = actor_cls.remote() - receiver = actor_cls.remote() - - shape = (10,) - dtype = torch.float16 - - # Passing a non-tensor value when _direct_return=True and tranport="nccl" - # fails. - with InputNode() as inp: - dag = sender.send.bind(inp.shape, inp.dtype, inp.value, inp.send_tensor) - dag = dag.with_type_hint( - TorchTensorType( - transport=TorchTensorType.NCCL, - _direct_return=True, - ) - ) - dag = receiver.recv.bind(dag) - - compiled_dag = dag.experimental_compile() - - ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=True) - assert ray.get(ref) == (1, shape, dtype) - - ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=False) - with pytest.raises(RayChannelError): - ray.get(ref) - - # For direct_return=True tensors, the DAG will be torn down after any task - # throws an application-level exception, such as when the task returns - # something other than a torch.Tensor. Check that we can no longer submit - # to the DAG. - with pytest.raises(RayChannelError): - ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=True) - - compiled_dag.teardown() - - # TODO(swang): This currently requires time.sleep to avoid some issue with - # following tests. - time.sleep(3) - - @pytest.mark.parametrize("static_shape", [False, True]) @pytest.mark.parametrize("direct_return", [False, True]) @pytest.mark.parametrize("overlap_gpu_communication", [False, True]) @@ -1103,6 +1061,8 @@ class TestNcclGroup(GPUCommunicator): A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. 
""" + import cupy as cp + def __init__(self, world_size, comm_id, actor_handles): self._world_size = world_size self._comm_id = comm_id @@ -1162,6 +1122,14 @@ def allreduce( self._inner.allreduce(send_buf, recv_buf, op) recv_buf += 1 + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._inner.recv_stream + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return self._inner.send_stream + def destroy(self) -> None: return self._inner.destroy() From 0098fe15c36e2dc3804cd766d9359a01b1c0d0d9 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Tue, 29 Oct 2024 17:35:09 +0000 Subject: [PATCH 06/16] up Signed-off-by: Rui Qiao --- python/ray/dag/dag_node_operation.py | 33 ++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index 789f70acd297..8f5c58e9b7c7 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -214,8 +214,7 @@ def __str__(self): class_name = ( self.actor_handle._ray_actor_creation_function_descriptor.class_name ) - actor_id = self._actor_id.hex() - actor_id_abbv = actor_id[:4] + "..." + actor_id_abbv = self._actor_id[:4] + "..." return ( class_name + "_" @@ -524,6 +523,36 @@ def _visualize_execution_schedule( color = "blue" if label == "nccl" else "black" dot.edge(node_repr, out_node_repr, label=label, color=color) + # Add legend + with dot.subgraph(name="cluster_legend") as legend: + legend.attr(label="Legend", labelloc="t", fontsize="20", bgcolor="lightgrey") + + # Single node and its explanation + legend.node("example_node", "Worker_3c6a... [0] bwd C 10,10\n") + explanation = ( + '<' # noqa + '' + '' # noqa + "" + '' + '' # noqa + '' # noqa + '' # noqa + '' # noqa + '' # noqa + "" + '' + '' # noqa + "" + '' + '' # noqa + '' # noqa + "
Node description format:
<actor_name>_<actor_id> [<task_index>] <method_name> <operation> <orig_index>, <overlap_index>
Node description fields:
actor_id: is abbreviated, only the first 4 characters are shown
operation: is R(READ), C(COMPUTE), or W(WRITE)
orig_index: the index in the original execution schedule
overlap_index: the index in the overlap-communication optimized execution schedule
If this is different from orig_index, the node is highlighted in red color
Node grouping:
The nodes belonging to the same actor are grouped in the same rectangular
Edges:
blue color: indicates NCCL channel
black color: indicates shared memory channel
>" + ) + + legend.node("example_explanation", explanation, shape="plaintext") + legend.edge("example_node", "example_explanation", style="invis") + logger.info( "Writing compiled graph schedule visualization " "to compiled_graph_schedule.png" From b6bb9ddd6b10b374878864aca8034545f3139a17 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Tue, 29 Oct 2024 13:56:35 -0700 Subject: [PATCH 07/16] up Signed-off-by: Rui Qiao --- python/ray/dag/dag_node_operation.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index 8f5c58e9b7c7..274e5713ba9b 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -473,6 +473,26 @@ def _visualize_execution_schedule( """ Visualize the execution schedule for each actor. + The visualization will be saved as a PNG file named `compiled_graph_schedule.png`. + Details of the visualization: # noqa + + Node description format: + _ [] , + + Node description fields: + actor_id: is abbreviated, only the first 4 characters are shown + operation: is R(READ), C(COMPUTE), or W(WRITE) + orig_index: the index in the original execution schedule + overlap_index: the index in the overlap-communication optimized execution schedule + If this is different from orig_index, the node is highlighted in red color + + Node grouping: + The nodes belonging to the same actor are grouped in the same rectangular + + Edges: + blue color: indicates NCCL channel + black color: indicates shared memory channel + Args: actor_to_execution_schedule: A dictionary that maps an actor handle to the execution schedule which is a list of operation nodes. From c672aa9aa151ef147d07fd8c058488132e45f399 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Tue, 29 Oct 2024 14:02:35 -0700 Subject: [PATCH 08/16] up Signed-off-by: Rui Qiao --- python/ray/dag/dag_node_operation.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index 274e5713ba9b..6eee99cf8fc8 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -23,7 +23,13 @@ class _DAGNodeOperationType(Enum): COMPUTE = "COMPUTE" WRITE = "WRITE" - def __str__(self): + def short_str(self): + """ + A short string representation of the operation type. + + Used in scenarios that conciseness is preferred, e.g., + in visualization of the execution schedule. 
+ """ if self == _DAGNodeOperationType.READ: return "R" elif self == _DAGNodeOperationType.COMPUTE: @@ -220,7 +226,7 @@ def __str__(self): + "_" + actor_id_abbv + f" [{self.operation.exec_task_idx}] " - + f"{self.operation.method_name} {self.operation.type}" + + f"{self.operation.method_name} {self.operation.type.short_str()}" ) @property From 2a733a75f5b10ee54d0488b758e9992cd5f8f21f Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Wed, 30 Oct 2024 00:19:05 +0000 Subject: [PATCH 09/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 1 - .../tests/experimental/test_collective_dag.py | 246 +---------------- .../ray/experimental/collective/conftest.py | 248 ++++++++++++++++++ 3 files changed, 255 insertions(+), 240 deletions(-) create mode 100644 python/ray/experimental/collective/conftest.py diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 31b7a914bdd1..495ba0047725 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -445,7 +445,6 @@ def prepare(self): if self.output_type_hint.requires_nccl(): nccl_group_id = _get_nccl_group_id(self.output_type_hint) nccl_group = ChannelContext.get_current().nccl_groups.get(nccl_group_id) - assert nccl_group is not None self._send_stream = nccl_group.send_stream if self.input_type_hints: for type_hint in self.input_type_hints: diff --git a/python/ray/dag/tests/experimental/test_collective_dag.py b/python/ray/dag/tests/experimental/test_collective_dag.py index 680e6fd27dfb..2d65bb3e55c9 100644 --- a/python/ray/dag/tests/experimental/test_collective_dag.py +++ b/python/ray/dag/tests/experimental/test_collective_dag.py @@ -2,24 +2,18 @@ import logging import os import sys -import uuid -import copy -from typing import Dict, FrozenSet, List, Optional, Set, Tuple +import ray.experimental.collective as collective import pytest -import ray -import ray.cluster_utils -import ray.experimental.collective as collective -import torch +from ray.experimental.collective.conftest import ( + AbstractNcclGroup, + CPUTorchTensorWorker, + check_nccl_group_init, + check_nccl_group_teardown, +) from ray.dag import InputNode, MultiOutputNode from ray.experimental.channel.torch_tensor_type import TorchTensorType -from ray.experimental.channel.common import ChannelContext -from ray.experimental.channel.gpu_communicator import ( - GPUCommunicator, - TorchTensorAllocator, -) from ray.tests.conftest import * # noqa -from ray.util.collective.types import ReduceOp logger = logging.getLogger(__name__) @@ -27,232 +21,6 @@ pytest.skip("Skipping, requires Linux or Mac.", allow_module_level=True) -@ray.remote -class CPUTorchTensorWorker: - def __init__(self): - self.device = "cpu" - - def return_tensor(self, size: int) -> torch.Tensor: - return torch.ones(size, device=self.device) - - def recv(self, tensor: torch.Tensor) -> Tuple[int, int]: - assert tensor.device == self.device - return tensor.shape, tensor[0] - - -def mock_do_init_nccl_group( - self, - group_id: str, - rank: int, - actors: List[ray.actor.ActorHandle], - custom_nccl_group: Optional[GPUCommunicator], -) -> None: - ctx = ChannelContext.get_current() - if custom_nccl_group is None: - nccl_group = AbstractNcclGroup(actors) - nccl_group.initialize(rank) - ctx.nccl_groups[group_id] = nccl_group - else: - custom_nccl_group.initialize(rank) - ctx.nccl_groups[group_id] = custom_nccl_group - - -def mock_do_destroy_nccl_group(self, group_id: str) -> None: - ctx = ChannelContext.get_current() - if group_id not in ctx.nccl_groups: - 
return - ctx.nccl_groups[group_id].destroy() - del ctx.nccl_groups[group_id] - - -class AbstractNcclGroup(GPUCommunicator): - """ - A dummy NCCL group for testing. - """ - - def __init__(self, actor_handles: List[ray.actor.ActorHandle]): - self._actor_handles = actor_handles - self._rank = None - - def initialize(self, rank: int) -> None: - self._rank = rank - - def get_rank(self, actor: ray.actor.ActorHandle) -> int: - return self._actor_handles.index(actor) - - def get_world_size(self) -> int: - return len(self._actor_handles) - - def get_self_rank(self) -> Optional[int]: - return self._rank - - def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: - return self._actor_handles - - def send(self, value: "torch.Tensor", peer_rank: int) -> None: - raise NotImplementedError - - def recv( - self, - shape: Tuple[int], - dtype: "torch.dtype", - peer_rank: int, - allocator: Optional[TorchTensorAllocator] = None, - ) -> "torch.Tensor": - raise NotImplementedError - - def allreduce( - self, - send_buf: "torch.Tensor", - recv_buf: "torch.Tensor", - op: ReduceOp = ReduceOp.SUM, - ) -> None: - raise NotImplementedError - - def destroy(self) -> None: - pass - - -class MockNcclGroupSet: - def __init__(self): - # Represents a mapping from a NCCL group ID to a set of actors and a custom - # NCCL group. - self.ids_to_actors_and_custom_comms: Dict[ - str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] - ] = {} - - def __call__( - self, - actors: List["ray.actor.ActorHandle"], - custom_nccl_group: Optional[GPUCommunicator] = None, - ) -> str: - group_id = str(uuid.uuid4()) - self.ids_to_actors_and_custom_comms[group_id] = ( - frozenset(actors), - custom_nccl_group, - ) - - if custom_nccl_group is None: - ranks = list(range(len(actors))) - else: - ranks = [custom_nccl_group.get_rank(actor) for actor in actors] - init_tasks = [ - actor.__ray_call__.remote( - mock_do_init_nccl_group, - group_id, - rank, - actors, - custom_nccl_group, - ) - for rank, actor in zip(ranks, actors) - ] - ray.get(init_tasks, timeout=30) - - ctx = ChannelContext.get_current() - if custom_nccl_group is not None: - ctx.nccl_groups[group_id] = custom_nccl_group - else: - ctx.nccl_groups[group_id] = AbstractNcclGroup(actors) - - return group_id - - def mock_destroy_nccl_group(self, group_id: str) -> None: - ctx = ChannelContext.get_current() - if group_id not in ctx.nccl_groups: - return - - actors, _ = self.ids_to_actors_and_custom_comms[group_id] - destroy_tasks = [ - actor.__ray_call__.remote( - mock_do_destroy_nccl_group, - group_id, - ) - for actor in actors - ] - ray.wait(destroy_tasks, timeout=30) - - if group_id in self.ids_to_actors_and_custom_comms: - del self.ids_to_actors_and_custom_comms[group_id] - ctx.nccl_groups[group_id].destroy() - del ctx.nccl_groups[group_id] - - def check_init( - self, - compiled_dag: "ray.dag.CompiledDAG", - actors_and_custom_comms: Set[ - Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] - ], - p2p_actors_and_custom_comm: Optional[ - Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] - ], - ) -> None: - assert len(self.ids_to_actors_and_custom_comms) == len(actors_and_custom_comms) - assert ( - set(self.ids_to_actors_and_custom_comms.values()) == actors_and_custom_comms - ) - - nccl_group_id_p2p = compiled_dag.nccl_group_id_p2p - if p2p_actors_and_custom_comm is None: - assert nccl_group_id_p2p is None - else: - assert nccl_group_id_p2p - assert ( - self.ids_to_actors_and_custom_comms[nccl_group_id_p2p] - == 
p2p_actors_and_custom_comm - ) - - def check_teardown(self, nccl_group_ids: List[str]) -> None: - ctx = ChannelContext.get_current() - for nccl_group_id in nccl_group_ids: - assert nccl_group_id not in self.ids_to_actors_and_custom_comms - assert nccl_group_id not in ctx.nccl_groups - - -def check_nccl_group_init( - monkeypatch, - dag: "ray.dag.DAGNode", - actors_and_custom_comms: Set[ - Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] - ], - p2p_actors_and_custom_comm: Optional[ - Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] - ] = None, -) -> "ray.dag.CompiledDAG": - mock_nccl_group_set = MockNcclGroupSet() - monkeypatch.setattr( - "ray.dag.compiled_dag_node._init_nccl_group", - mock_nccl_group_set, - ) - monkeypatch.setattr( - "ray.dag.collective_node._init_nccl_group", - mock_nccl_group_set, - ) - - compiled_dag = dag.experimental_compile() - mock_nccl_group_set.check_init( - compiled_dag, - actors_and_custom_comms, - p2p_actors_and_custom_comm, - ) - - return compiled_dag, mock_nccl_group_set - - -def check_nccl_group_teardown( - monkeypatch, - compiled_dag: "ray.dag.CompiledDAG", - mock_nccl_group_set: MockNcclGroupSet, -): - monkeypatch.setattr( - "ray.dag.compiled_dag_node._destroy_nccl_group", - mock_nccl_group_set.mock_destroy_nccl_group, - ) - - nccl_group_ids = copy.deepcopy(compiled_dag.nccl_group_ids) - compiled_dag.teardown() - mock_nccl_group_set.check_teardown(nccl_group_ids) - - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_all_reduce_duplicate_actors(ray_start_regular): """ diff --git a/python/ray/experimental/collective/conftest.py b/python/ray/experimental/collective/conftest.py new file mode 100644 index 000000000000..bf34c5916db4 --- /dev/null +++ b/python/ray/experimental/collective/conftest.py @@ -0,0 +1,248 @@ +import copy +from typing import Dict, FrozenSet, List, Optional, Set, Tuple +import uuid +import torch +import ray +from ray.experimental.channel.gpu_communicator import ( + GPUCommunicator, + ReduceOp, + TorchTensorAllocator, +) +from ray.experimental.channel.common import ChannelContext + + +class AbstractNcclGroup(GPUCommunicator): + """ + A dummy NCCL group for testing. 
+ """ + + import cupy as cp + + def __init__(self, actor_handles: List[ray.actor.ActorHandle]): + self._actor_handles = actor_handles + self._rank = None + + def initialize(self, rank: int) -> None: + self._rank = rank + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + return self._actor_handles.index(actor) + + def get_world_size(self) -> int: + return len(self._actor_handles) + + def get_self_rank(self) -> Optional[int]: + return self._rank + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + raise NotImplementedError + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + raise NotImplementedError + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ) -> None: + raise NotImplementedError + + @property + def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return None + + @property + def send_stream(self) -> Optional["cp.cuda.ExternalStream"]: + return None + + def destroy(self) -> None: + pass + + +class MockNcclGroupSet: + def __init__(self): + # Represents a mapping from a NCCL group ID to a set of actors and a custom + # NCCL group. + self.ids_to_actors_and_custom_comms: Dict[ + str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ] = {} + + def __call__( + self, + actors: List["ray.actor.ActorHandle"], + custom_nccl_group: Optional[GPUCommunicator] = None, + use_communication_streams: bool = False, + ) -> str: + group_id = str(uuid.uuid4()) + self.ids_to_actors_and_custom_comms[group_id] = ( + frozenset(actors), + custom_nccl_group, + ) + + if custom_nccl_group is None: + ranks = list(range(len(actors))) + else: + ranks = [custom_nccl_group.get_rank(actor) for actor in actors] + init_tasks = [ + actor.__ray_call__.remote( + mock_do_init_nccl_group, + group_id, + rank, + actors, + custom_nccl_group, + ) + for rank, actor in zip(ranks, actors) + ] + ray.get(init_tasks, timeout=30) + + ctx = ChannelContext.get_current() + if custom_nccl_group is not None: + ctx.nccl_groups[group_id] = custom_nccl_group + else: + ctx.nccl_groups[group_id] = AbstractNcclGroup(actors) + + return group_id + + def mock_destroy_nccl_group(self, group_id: str) -> None: + ctx = ChannelContext.get_current() + if group_id not in ctx.nccl_groups: + return + + actors, _ = self.ids_to_actors_and_custom_comms[group_id] + destroy_tasks = [ + actor.__ray_call__.remote( + mock_do_destroy_nccl_group, + group_id, + ) + for actor in actors + ] + ray.wait(destroy_tasks, timeout=30) + + if group_id in self.ids_to_actors_and_custom_comms: + del self.ids_to_actors_and_custom_comms[group_id] + ctx.nccl_groups[group_id].destroy() + del ctx.nccl_groups[group_id] + + def check_init( + self, + compiled_dag: "ray.dag.CompiledDAG", + actors_and_custom_comms: Set[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ], + p2p_actors_and_custom_comm: Optional[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ], + ) -> None: + assert len(self.ids_to_actors_and_custom_comms) == len(actors_and_custom_comms) + assert ( + set(self.ids_to_actors_and_custom_comms.values()) == actors_and_custom_comms + ) + + nccl_group_id_p2p = compiled_dag.nccl_group_id_p2p + if p2p_actors_and_custom_comm is None: + assert nccl_group_id_p2p is None + else: + assert nccl_group_id_p2p + assert ( + 
self.ids_to_actors_and_custom_comms[nccl_group_id_p2p] + == p2p_actors_and_custom_comm + ) + + def check_teardown(self, nccl_group_ids: List[str]) -> None: + ctx = ChannelContext.get_current() + for nccl_group_id in nccl_group_ids: + assert nccl_group_id not in self.ids_to_actors_and_custom_comms + assert nccl_group_id not in ctx.nccl_groups + +@ray.remote +class CPUTorchTensorWorker: + def __init__(self): + self.device = "cpu" + + def return_tensor(self, size: int) -> torch.Tensor: + return torch.ones(size, device=self.device) + + def recv(self, tensor: torch.Tensor) -> Tuple[int, int]: + assert tensor.device == self.device + return tensor.shape, tensor[0] + + + +def mock_do_init_nccl_group( + self, + group_id: str, + rank: int, + actors: List[ray.actor.ActorHandle], + custom_nccl_group: Optional[GPUCommunicator], +) -> None: + ctx = ChannelContext.get_current() + if custom_nccl_group is None: + nccl_group = AbstractNcclGroup(actors) + nccl_group.initialize(rank) + ctx.nccl_groups[group_id] = nccl_group + else: + custom_nccl_group.initialize(rank) + ctx.nccl_groups[group_id] = custom_nccl_group + + +def mock_do_destroy_nccl_group(self, group_id: str) -> None: + ctx = ChannelContext.get_current() + if group_id not in ctx.nccl_groups: + return + ctx.nccl_groups[group_id].destroy() + del ctx.nccl_groups[group_id] + + +def check_nccl_group_init( + monkeypatch, + dag: "ray.dag.DAGNode", + actors_and_custom_comms: Set[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ], + p2p_actors_and_custom_comm: Optional[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ] = None, +) -> "ray.dag.CompiledDAG": + mock_nccl_group_set = MockNcclGroupSet() + monkeypatch.setattr( + "ray.dag.compiled_dag_node._init_nccl_group", + mock_nccl_group_set, + ) + monkeypatch.setattr( + "ray.dag.collective_node._init_nccl_group", + mock_nccl_group_set, + ) + + compiled_dag = dag.experimental_compile() + mock_nccl_group_set.check_init( + compiled_dag, + actors_and_custom_comms, + p2p_actors_and_custom_comm, + ) + + return compiled_dag, mock_nccl_group_set + + +def check_nccl_group_teardown( + monkeypatch, + compiled_dag: "ray.dag.CompiledDAG", + mock_nccl_group_set: MockNcclGroupSet, +): + monkeypatch.setattr( + "ray.dag.compiled_dag_node._destroy_nccl_group", + mock_nccl_group_set.mock_destroy_nccl_group, + ) + + nccl_group_ids = copy.deepcopy(compiled_dag.nccl_group_ids) + compiled_dag.teardown() + mock_nccl_group_set.check_teardown(nccl_group_ids) From 56e1517b52d7c571b32bcaed388b3054da33f7ab Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Wed, 30 Oct 2024 22:14:08 +0000 Subject: [PATCH 10/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 10 ++++++++-- python/ray/experimental/collective/conftest.py | 10 ++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 495ba0047725..d459c4f0f1be 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -136,7 +136,7 @@ def do_exec_tasks( """ try: for task in tasks: - task.prepare() + task.prepare(overlap_gpu_communication) done = False while True: @@ -427,7 +427,7 @@ def cancel(self): self.input_reader.close() self.output_writer.close() - def prepare(self): + def prepare(self, overlap_gpu_communication: bool = False): """ Prepare the task for execution. The `exec_operation` function can only be called after `prepare` has been called. 
@@ -442,9 +442,15 @@ def prepare(self): self._send_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() self._recv_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() + if not overlap_gpu_communication: + return + + # Set up send_stream and recv_stream when overlap_gpu_communication + # is configured if self.output_type_hint.requires_nccl(): nccl_group_id = _get_nccl_group_id(self.output_type_hint) nccl_group = ChannelContext.get_current().nccl_groups.get(nccl_group_id) + assert nccl_group is not None self._send_stream = nccl_group.send_stream if self.input_type_hints: for type_hint in self.input_type_hints: diff --git a/python/ray/experimental/collective/conftest.py b/python/ray/experimental/collective/conftest.py index bf34c5916db4..6ae75c1a7b77 100644 --- a/python/ray/experimental/collective/conftest.py +++ b/python/ray/experimental/collective/conftest.py @@ -1,14 +1,16 @@ import copy -from typing import Dict, FrozenSet, List, Optional, Set, Tuple import uuid +from typing import Dict, FrozenSet, List, Optional, Set, Tuple + import torch + import ray +from ray.experimental.channel.common import ChannelContext from ray.experimental.channel.gpu_communicator import ( GPUCommunicator, ReduceOp, TorchTensorAllocator, ) -from ray.experimental.channel.common import ChannelContext class AbstractNcclGroup(GPUCommunicator): @@ -56,7 +58,7 @@ def allreduce( op: ReduceOp = ReduceOp.SUM, ) -> None: raise NotImplementedError - + @property def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: return None @@ -164,6 +166,7 @@ def check_teardown(self, nccl_group_ids: List[str]) -> None: assert nccl_group_id not in self.ids_to_actors_and_custom_comms assert nccl_group_id not in ctx.nccl_groups + @ray.remote class CPUTorchTensorWorker: def __init__(self): @@ -177,7 +180,6 @@ def recv(self, tensor: torch.Tensor) -> Tuple[int, int]: return tensor.shape, tensor[0] - def mock_do_init_nccl_group( self, group_id: str, From 0fba86defd816c00bc5fd735bebfca6ecd6aebd8 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Wed, 30 Oct 2024 17:30:26 -0700 Subject: [PATCH 11/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 22 ++++++++---- python/ray/dag/dag_node_operation.py | 31 ++++++++-------- python/ray/dag/dag_operation_future.py | 35 ++++++++++++------- .../experimental/channel/torch_tensor_type.py | 2 +- 4 files changed, 56 insertions(+), 34 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index d459c4f0f1be..884a056a4efd 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -238,12 +238,8 @@ def _get_nccl_group_id(type_hint: ChannelOutputType) -> Optional[str]: The NCCL group ID if the type hint requires NCCL, otherwise None. """ if type_hint.requires_nccl(): - if isinstance(type_hint, SharedMemoryType): - assert type_hint._contains_type.requires_nccl() - return _get_nccl_group_id(type_hint._contains_type) - else: - assert isinstance(type_hint, TorchTensorType) - return type_hint.nccl_group_id + assert isinstance(type_hint, TorchTensorType) + return type_hint.nccl_group_id return None @@ -431,6 +427,10 @@ def prepare(self, overlap_gpu_communication: bool = False): """ Prepare the task for execution. The `exec_operation` function can only be called after `prepare` has been called. 
+ + Args: + overlap_gpu_communication: Whether to overlap GPU communication with + computation during DAG execution to improve performance """ for typ_hint in self.input_type_hints: typ_hint.register_custom_serializer() @@ -493,6 +493,11 @@ def reset_and_wait_intermediate_future(self) -> Any: """ Reset the intermediate future and wait for the result. + This does not block the CPU because: + - If the future is a ResolvedFuture, the result is immediately returned. + - If the future is a GPUFuture, the result is only waited by the current + CUDA stream, and the CPU is not blocked. + Returns: The result of a READ or COMPUTE operation from the intermediate future. """ @@ -1008,6 +1013,11 @@ def _preprocess(self) -> None: # Collect NCCL collective operations. if isinstance(dag_node, CollectiveOutputNode): nccl_collective_ops.add(dag_node.collective_op) + assert not self._overlap_gpu_communication, ( + "Currently, the overlap_gpu_communication option is not " + "supported for NCCL collective operations. Please set " + "overlap_gpu_communication=False." + ) elif isinstance(dag_node, InputNode): if dag_node.type_hint.requires_nccl(): raise ValueError( diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index 6eee99cf8fc8..f56f4952d22a 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -23,9 +23,9 @@ class _DAGNodeOperationType(Enum): COMPUTE = "COMPUTE" WRITE = "WRITE" - def short_str(self): + def viz_str(self): """ - A short string representation of the operation type. + A string representation of the operation type to be used in visualization. Used in scenarios that conciseness is preferred, e.g., in visualization of the execution schedule. @@ -216,7 +216,10 @@ def is_nccl_write(self) -> bool: def is_nccl_op(self) -> bool: return self.is_nccl_collective or self.is_nccl_write - def __str__(self): + def viz_str(self): + """ + A string representation of the node to be used in visualization. + """ class_name = ( self.actor_handle._ray_actor_creation_function_descriptor.class_name ) @@ -226,7 +229,7 @@ def __str__(self): + "_" + actor_id_abbv + f" [{self.operation.exec_task_idx}] " - + f"{self.operation.method_name} {self.operation.type.short_str()}" + + f"{self.operation.method_name} {self.operation.type.viz_str()}" ) @property @@ -455,16 +458,16 @@ def _build_dag_node_operation_graph( return graph -def _node_repr(node: _DAGOperationGraphNode, idx: int, optimized_index: int): +def _node_viz_str(node: _DAGOperationGraphNode, idx: int, optimized_index: int): """ - Representation of a node in the visualization of the execution schedule. + A string representation of a node in the visualization of the execution schedule. Args: node: The node to be represented. idx: The index of the node in the execution schedule. optimized_index: The index of the node in the optimized execution schedule. """ - return str(node) + f" {idx},{optimized_index}" + return node.viz_str() + f" {idx},{optimized_index}" def _visualize_execution_schedule( @@ -519,7 +522,7 @@ def _visualize_execution_schedule( ) dot = graphviz.Digraph(comment="DAG") - node_to_repr: Dict[_DAGOperationGraphNode, str] = {} + node_to_viz: Dict[_DAGOperationGraphNode, str] = {} # TODO: only visualize the execution schedule if the overlapped schedule is None. 
if actor_to_overlapped_schedule is None: @@ -534,20 +537,20 @@ def _visualize_execution_schedule( subgraph.attr(rank=execution_nodes[0]._actor_id) for i, node in enumerate(execution_nodes): optimized_index = node_to_optimized_index.get(node) - node_repr = _node_repr(node, i, optimized_index) + node_viz = _node_viz_str(node, i, optimized_index) color = "red" if optimized_index != i else "black" - subgraph.node(node_repr, node_repr, color=color) - node_to_repr[node] = node_repr + subgraph.node(node_viz, node_viz, color=color) + node_to_viz[node] = node_viz for actor, execution_nodes in actor_to_execution_schedule.items(): for i, node in enumerate(execution_nodes): - node_repr = node_to_repr[node] + node_viz = node_to_viz[node] for out_edge, label in node.out_edges.items(): out_task_idx, out_op_type = out_edge out_node = graph[out_task_idx][out_op_type] - out_node_repr = node_to_repr[out_node] + out_node_repr = node_to_viz[out_node] color = "blue" if label == "nccl" else "black" - dot.edge(node_repr, out_node_repr, label=label, color=color) + dot.edge(node_viz, out_node_repr, label=label, color=color) # Add legend with dot.subgraph(name="cluster_legend") as legend: diff --git a/python/ray/dag/dag_operation_future.py b/python/ray/dag/dag_operation_future.py index 8ab165cb09fd..e92707309ff1 100644 --- a/python/ray/dag/dag_operation_future.py +++ b/python/ray/dag/dag_operation_future.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from ray.util.annotations import DeveloperAPI if TYPE_CHECKING: - import torch import cupy as cp T = TypeVar("T") @@ -51,19 +50,29 @@ def wait(self): @DeveloperAPI -class GPUFuture(DAGOperationFuture["torch.Tensor"]): +class GPUFuture(DAGOperationFuture[Any]): """ - A future that represents a GPU operation. + A future for a GPU event on a CUDA stream. + + This future wraps a buffer, and records an event on the given stream + when it is created. When the future is waited on, it makes the current + CUDA stream wait on the event, then returns the buffer. + + The buffer must be a GPU tensor produced by an earlier operation launched + on the given stream, or it could be CPU data. Then the future guarantees + that when the wait() returns, the buffer is ready on the current stream. + + The future does not block CPU. """ - def __init__(self, buf: "torch.Tensor", stream: Optional["cp.cuda.Stream"] = None): + def __init__(self, buf: Any, stream: Optional["cp.cuda.Stream"] = None): """ - Initialize a GPU future. + Initialize a GPU future on the given stream. Args: buf: The buffer to return when the future is resolved. - stream: The CUDA stream to record the event on. If None, the current - stream is used. + stream: The CUDA stream to record the event on, this event is waited + on when the future is resolved. If None, the current stream is used. """ import cupy as cp @@ -74,13 +83,13 @@ def __init__(self, buf: "torch.Tensor", stream: Optional["cp.cuda.Stream"] = Non self._event = cp.cuda.Event() self._event.record(stream) - def wait(self) -> "torch.Tensor": + def wait(self) -> Any: """ - Wait for the future and return the result from the GPU operation. + Wait for the future on the current CUDA stream and return the result from + the GPU operation. This operation does not block CPU. 
""" import cupy as cp - if self._event is not None: - current_stream = cp.cuda.get_current_stream() - current_stream.wait_event(self._event) + current_stream = cp.cuda.get_current_stream() + current_stream.wait_event(self._event) return self._buf diff --git a/python/ray/experimental/channel/torch_tensor_type.py b/python/ray/experimental/channel/torch_tensor_type.py index 44ebe83e5f8d..8615f18b7d65 100644 --- a/python/ray/experimental/channel/torch_tensor_type.py +++ b/python/ray/experimental/channel/torch_tensor_type.py @@ -157,7 +157,7 @@ def create_channel( # Data does not require NCCL. Transfer via host memory using a # shared-memory channel. - # TODO(swang): Allow the initial max buffer size to bereaders overridden. + # TODO(swang): Allow the initial max buffer size to be overridden. typ = SharedMemoryType() return typ.create_channel(writer, reader_and_node_list, read_by_adag_driver) From a2404726048c345f4b3f2207ae4684614a726ae5 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 31 Oct 2024 09:17:21 -0700 Subject: [PATCH 12/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 11 ++++--- python/ray/dag/dag_node_operation.py | 43 +++++++++++++++++----------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 884a056a4efd..8d546b4f347c 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -474,12 +474,12 @@ def wrap_and_set_intermediate_future( Wrap the value in a `DAGOperationFuture` and store to the intermediate future. The value corresponds to result of a READ or COMPUTE operation. - If wrap_in_gpu_future is True, the value will be wrapped in a _GPUFuture, + If wrap_in_gpu_future is True, the value will be wrapped in a GPUFuture, Otherwise, the future will be a ResolvedFuture. Args: val: The value to wrap in a future. - wrap_in_gpu_future: Whether to wrap the value in a _GPUFuture. + wrap_in_gpu_future: Whether to wrap the value in a GPUFuture. """ assert self._intermediate_future is None @@ -493,7 +493,7 @@ def reset_and_wait_intermediate_future(self) -> Any: """ Reset the intermediate future and wait for the result. - This does not block the CPU because: + The wait does not block the CPU because: - If the future is a ResolvedFuture, the result is immediately returned. - If the future is a GPUFuture, the result is only waited by the current CUDA stream, and the CPU is not blocked. @@ -521,7 +521,7 @@ def _read(self, overlap_gpu_communication: bool) -> bool: try: input_data = self.input_reader.read() # When overlap_gpu_communication is enabled, wrap the result in - # a GPU future so that this read operation (communication) can + # a GPUFuture so that this read operation (communication) can # be overlapped with computation. self.wrap_and_set_intermediate_future( input_data, wrap_in_gpu_future=overlap_gpu_communication @@ -579,7 +579,7 @@ def _compute( except Exception as exc: output_val = _wrap_exception(exc) - # When overlap_gpu_communication is enabled, wrap the result in a GPU future + # When overlap_gpu_communication is enabled, wrap the result in a GPUFuture # so that this compute operation can be overlapped with communication. 
self.wrap_and_set_intermediate_future(
             output_val, wrap_in_gpu_future=overlap_gpu_communication
@@ -1749,7 +1749,6 @@ def _build_execution_schedule(
             )
             return _extract_execution_schedule(actor_to_overlapped_schedule)
         else:
-            actor_to_overlapped_schedule = None
             return _extract_execution_schedule(actor_to_execution_schedule)
 
     def _detect_deadlock(self) -> bool:
diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py
index f56f4952d22a..e1682bc8a9c2 100644
--- a/python/ray/dag/dag_node_operation.py
+++ b/python/ray/dag/dag_node_operation.py
@@ -27,8 +27,7 @@ def viz_str(self):
         """
         A string representation of the operation type to be used in visualization.
 
-        Used in scenarios that conciseness is preferred, e.g.,
-        in visualization of the execution schedule.
+        The result string is a single character because conciseness is preferred.
         """
         if self == _DAGNodeOperationType.READ:
             return "R"
@@ -54,7 +53,7 @@ def __init__(
                 than tasks that appear in the current compiled DAG.
             operation_type: The type of operation to perform.
             method_name: The name of the method that this operation originates
-                from. This is only for debugging purposes.
+                from. This is only for visualization and debugging purposes.
         """
         self.exec_task_idx = exec_task_idx
         self.type = operation_type
@@ -67,15 +66,18 @@ def __repr__(self):
             f" type: {self.type})"
         )
 
-    def __str__(self):
-        return f"([{self.exec_task_idx}] {self.method_name} {self.type})"
+    def viz_str(self):
+        """
+        A string representation of the operation to be used in visualization.
+        """
+        return f"([{self.exec_task_idx}] {self.method_name} {self.type.viz_str()})"
 
     def __hash__(self):
         return hash((self.exec_task_idx, self.type))
 
     def __eq__(self, other):
         # An operation is uniquely identified by its `exec_task_idx` and type.
-        # `func_name` is only for debugging purposes.
+        # `method_name` is only for debugging purposes.
         return self.exec_task_idx == other.exec_task_idx and self.type == other.type
 
 
@@ -109,10 +111,11 @@ def __init__(
         # Each tuple (the key) contains an integer `task_idx`, which can be
        # used to index into `idx_to_task` to get the corresponding task,
         # and a `_DAGNodeOperationType`, which can be READ, COMPUTE, or WRITE.
-        # The string (the value) is the label of the edge, which will be used
-        # to annotate the edge in the visualization of the execution schedule.
-        self.in_edges: Dict[Tuple[int, _DAGNodeOperationType], str] = {}
-        self.out_edges: Dict[Tuple[int, _DAGNodeOperationType], str] = {}
+        # The value is the visualization information of the edge: a tuple of
+        # the edge's label and a boolean indicating whether the edge is a
+        # control dependency.
+        self.in_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
+        self.out_edges: Dict[Tuple[int, _DAGNodeOperationType], Tuple[str, bool]] = {}
         # The collective nodes are the nodes that belong to the same collective
         # operation. Each node is represented by a tuple of its task idx and type.
         self.collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set()
@@ -238,11 +241,13 @@ def _actor_id(self):
 
 
 def _add_edge(
-    from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode, label: str = ""
+    from_node: _DAGOperationGraphNode,
+    to_node: _DAGOperationGraphNode,
+    label: str = "",
+    control_dependency: bool = False,
 ):
     """
-    Add an edge from `from_node` to `to_node`. An edge is a tuple of
-    the operation's `task_idx` and type.
+    Add an edge from `from_node` to `to_node`.
 
     Args:
         from_node: The node from which the edge originates.
@@ -250,8 +255,14 @@ def _add_edge( label: The label of the edge. This will be used to annotate the edge in the visualization of the execution schedule. """ - from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = label - to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = label + from_node.out_edges[(to_node.task_idx, to_node.operation.type)] = ( + label, + control_dependency, + ) + to_node.in_edges[(from_node.task_idx, from_node.operation.type)] = ( + label, + control_dependency, + ) def _push_candidate_node_if_ready( @@ -408,7 +419,7 @@ def _build_dag_node_operation_graph( # Add an edge from COMPUTE with `bind_index` i to COMPUTE with # `bind_index` i+1 if they belong to the same actor. if prev_compute_node is not None: - _add_edge(prev_compute_node, compute_node, "next") + _add_edge(prev_compute_node, compute_node, "", True) prev_compute_node = compute_node assert task_idx not in graph graph[task_idx] = { From bcf1115f1618f30f8e45929c2a7c4d47b6502281 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 31 Oct 2024 09:35:32 -0700 Subject: [PATCH 13/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 2 + python/ray/dag/dag_node_operation.py | 21 ++++-- .../experimental/test_mocked_nccl_dag.py | 65 +------------------ python/ray/experimental/channel/conftest.py | 20 ------ .../experimental/channel/gpu_communicator.py | 1 - 5 files changed, 17 insertions(+), 92 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 8d546b4f347c..a63fc53e7f60 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -1749,6 +1749,8 @@ def _build_execution_schedule( ) return _extract_execution_schedule(actor_to_overlapped_schedule) else: + if RAY_ADAG_VISUALIZE_SCHEDULE: + _visualize_execution_schedule(actor_to_execution_schedule, None, graph) return _extract_execution_schedule(actor_to_execution_schedule) def _detect_deadlock(self) -> bool: diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index e1682bc8a9c2..2a68c483dc4f 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -510,8 +510,9 @@ def _visualize_execution_schedule( The nodes belonging to the same actor are grouped in the same rectangular Edges: - blue color: indicates NCCL channel - black color: indicates shared memory channel + black color: indicates shared memory channel (also annotated with "shm") + blue color: indicates NCCL channel (also annotated with "nccl") + dashed edge: indicates a control dependency between compute nodes Args: actor_to_execution_schedule: A dictionary that maps an actor handle to @@ -535,8 +536,9 @@ def _visualize_execution_schedule( dot = graphviz.Digraph(comment="DAG") node_to_viz: Dict[_DAGOperationGraphNode, str] = {} - # TODO: only visualize the execution schedule if the overlapped schedule is None. 
if actor_to_overlapped_schedule is None: + # TODO(rui): make the visualization more concise by only displaying + # the original schedule actor_to_overlapped_schedule = actor_to_execution_schedule for actor, execution_nodes in actor_to_execution_schedule.items(): overlapped_schedule = actor_to_overlapped_schedule[actor] @@ -556,12 +558,14 @@ def _visualize_execution_schedule( for actor, execution_nodes in actor_to_execution_schedule.items(): for i, node in enumerate(execution_nodes): node_viz = node_to_viz[node] - for out_edge, label in node.out_edges.items(): + for out_edge, viz_info in node.out_edges.items(): + label, control_dependency = viz_info out_task_idx, out_op_type = out_edge out_node = graph[out_task_idx][out_op_type] out_node_repr = node_to_viz[out_node] color = "blue" if label == "nccl" else "black" - dot.edge(node_viz, out_node_repr, label=label, color=color) + style = "dashed" if control_dependency else "solid" + dot.edge(node_viz, out_node_repr, label=label, color=color, style=style) # Add legend with dot.subgraph(name="cluster_legend") as legend: @@ -585,8 +589,9 @@ def _visualize_execution_schedule( 'The nodes belonging to the same actor are grouped in the same rectangular' # noqa "" 'Edges:' - 'blue color: indicates NCCL channel' # noqa - 'black color: indicates shared memory channel' # noqa + 'black color: indicates shared memory channel (also annotated with "shm")' # noqa + 'blue color: indicates NCCL channel (also annotated with "nccl")' # noqa + 'dashed edge: indicates a control dependency between compute nodes' # noqa ">" ) @@ -696,6 +701,8 @@ def _generate_overlapped_execution_schedule( compute node to swap with so that the NCCL read operation can be overlapped with computation. + Collective operations are not yet supported. + Args: actor_to_execution_schedule: A dictionary that maps an actor handle to the existing execution schedule for the actor. The schedule is a list diff --git a/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py b/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py index 9ffe00116cd7..634d411272f2 100644 --- a/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py +++ b/python/ray/dag/tests/experimental/test_mocked_nccl_dag.py @@ -15,7 +15,7 @@ ) from ray.tests.conftest import * # noqa from ray.tests.conftest import wait_for_condition -from ray.dag import InputNode, MultiOutputNode +from ray.dag import InputNode def error_logged(capsys, msg): @@ -416,69 +416,6 @@ def test_p2p_static_shape_and_direct_return( wait_for_condition(lambda: error_logged(capsys, msg)) -@pytest.mark.parametrize( - "ray_start_cluster", - [ - { - "num_cpus": 3, - "num_gpus": 3, - "num_nodes": 1, - } - ], - indirect=True, -) -@pytest.mark.parametrize("overlap_gpu_communication", [True]) -def test_overlap_gpu_communication(ray_start_cluster, overlap_gpu_communication): - # Barrier name should be barrier-{sender rank}-{receiver rank}. - # Create a barrier in both directions because we don't know which rank will - # get assigned to sender and receiver. 
- barriers = [ # noqa - Barrier.options(name=f"barrier-{i}-{j}").remote() - for i in range(3) - for j in range(3) - if i != j - ] - - sender1 = MockedWorker.remote() - sender2 = MockedWorker.remote() - receiver = MockedWorker.remote() - - ray.get( - [ - sender1.start_mock.remote(), - sender2.start_mock.remote(), - receiver.start_mock.remote(), - ] - ) - - shape = (10,) - dtype = torch.float16 - - with InputNode() as inp: - branch1 = sender1.send.bind(shape, dtype, inp) - branch1 = branch1.with_type_hint( - TorchTensorType(transport="nccl", _static_shape=True, _direct_return=True) - ) - branch2 = sender2.send.bind(shape, dtype, inp) - branch2 = branch2.with_type_hint( - TorchTensorType(transport="nccl", _static_shape=True, _direct_return=True) - ) - branch1 = receiver.recv.bind(branch1) - branch2 = receiver.recv.bind(branch2) - dag = MultiOutputNode([branch1, branch2]) - - compiled_dag = dag.experimental_compile( - _overlap_gpu_communication=overlap_gpu_communication - ) - - for i in range(3): - ref = compiled_dag.execute(i) - result = ray.get(ref) - assert result == [(i, shape, dtype)] * 2 - - compiled_dag.teardown() - - if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) diff --git a/python/ray/experimental/channel/conftest.py b/python/ray/experimental/channel/conftest.py index 156ae7b48d81..47deb762bec7 100644 --- a/python/ray/experimental/channel/conftest.py +++ b/python/ray/experimental/channel/conftest.py @@ -1,5 +1,4 @@ import asyncio -import time from collections import defaultdict from typing import Optional, Tuple from unittest import mock @@ -107,21 +106,6 @@ def destroy(self) -> None: ray.kill(barrier) -class MockGPUFuture: - def __init__(self, buf, stream=None): - from ray.dag.dag_operation_future import GPUFuture - - self._inner = GPUFuture(buf, stream) - self._init_ts = time.monotonic() - print(f"Created GPUFuture at {self._init_ts}") - - def wait(self): - result = self._inner.wait() - self._wait_ts = time.monotonic() - print(f"Waited for {self._wait_ts - self._init_ts} seconds") - return result - - def start_nccl_mock(): """ Patch methods that require CUDA. @@ -161,10 +145,6 @@ def start_nccl_mock(): ) tensor_allocator_patcher.start() - # Mock GPUFuture - ray.dag.dag_operation_future.GPUFuture = MockGPUFuture - print(f"{ray.dag.dag_operation_future.GPUFuture=}") - ctx = ray_channel.ChannelContext.get_current() ctx.set_torch_device(torch.device("cuda")) diff --git a/python/ray/experimental/channel/gpu_communicator.py b/python/ray/experimental/channel/gpu_communicator.py index acb64c9e5da1..6edd5035471c 100644 --- a/python/ray/experimental/channel/gpu_communicator.py +++ b/python/ray/experimental/channel/gpu_communicator.py @@ -80,7 +80,6 @@ def send(self, value: "torch.Tensor", peer_rank: int) -> None: value: The torch.Tensor to send. It should already be on this actor's default device. peer_rank: The rank of the actor to send to. - future: An optional future to wait on before sending. 
""" raise NotImplementedError From a7b2027513dd6efd2e70b1919cb3b9350518e598 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 31 Oct 2024 13:12:24 -0700 Subject: [PATCH 14/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 16 ++++++++-------- python/ray/dag/dag_operation_future.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index a63fc53e7f60..144ee5caa037 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -14,6 +14,7 @@ from ray.dag.dag_operation_future import GPUFuture, DAGOperationFuture, ResolvedFuture from ray.experimental.channel.cached_channel import CachedChannel from ray.experimental.channel.gpu_communicator import GPUCommunicator +from ray.dag.constants import RAY_ADAG_VISUALIZE_SCHEDULE import ray from ray.exceptions import RayTaskError, RayChannelError from ray.experimental.compiled_dag_ref import ( @@ -136,7 +137,7 @@ def do_exec_tasks( """ try: for task in tasks: - task.prepare(overlap_gpu_communication) + task.prepare(overlap_gpu_communication=overlap_gpu_communication) done = False while True: @@ -1736,21 +1737,20 @@ def _build_execution_schedule( actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) # Step 3: Overlap GPU communication for the execution schedule if configured + actor_to_overlapped_schedule = None if self._overlap_gpu_communication: actor_to_overlapped_schedule = _generate_overlapped_execution_schedule( actor_to_execution_schedule ) - from ray.dag.constants import RAY_ADAG_VISUALIZE_SCHEDULE + if RAY_ADAG_VISUALIZE_SCHEDULE: + _visualize_execution_schedule( + actor_to_execution_schedule, actor_to_overlapped_schedule, graph + ) - if RAY_ADAG_VISUALIZE_SCHEDULE: - _visualize_execution_schedule( - actor_to_execution_schedule, actor_to_overlapped_schedule, graph - ) + if actor_to_overlapped_schedule is not None: return _extract_execution_schedule(actor_to_overlapped_schedule) else: - if RAY_ADAG_VISUALIZE_SCHEDULE: - _visualize_execution_schedule(actor_to_execution_schedule, None, graph) return _extract_execution_schedule(actor_to_execution_schedule) def _detect_deadlock(self) -> bool: diff --git a/python/ray/dag/dag_operation_future.py b/python/ray/dag/dag_operation_future.py index e92707309ff1..33d790515d3c 100644 --- a/python/ray/dag/dag_operation_future.py +++ b/python/ray/dag/dag_operation_future.py @@ -62,7 +62,7 @@ class GPUFuture(DAGOperationFuture[Any]): on the given stream, or it could be CPU data. Then the future guarantees that when the wait() returns, the buffer is ready on the current stream. - The future does not block CPU. + The `wait()` does not block CPU. 
""" def __init__(self, buf: Any, stream: Optional["cp.cuda.Stream"] = None): From 256cc421bea108c9920b05b111dae0fedf100599 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 31 Oct 2024 13:19:34 -0700 Subject: [PATCH 15/16] up Signed-off-by: Rui Qiao --- python/ray/dag/compiled_dag_node.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 144ee5caa037..aa6241d4c2d6 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -3,7 +3,17 @@ from collections import defaultdict from contextlib import nullcontext from dataclasses import dataclass, asdict -from typing import Any, Dict, FrozenSet, List, Tuple, Union, Optional, Set +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + List, + Tuple, + Union, + Optional, + Set, +) import logging import threading import time @@ -59,6 +69,8 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +if TYPE_CHECKING: + import cupy as cp logger = logging.getLogger(__name__) @@ -439,8 +451,6 @@ def prepare(self, overlap_gpu_communication: bool = False): self.input_reader.start() self.output_writer.start() - import cupy as cp - self._send_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() self._recv_stream: Union["cp.cuda.Stream", nullcontext] = nullcontext() if not overlap_gpu_communication: From 100ba2eb0d6cf3e60eccfd9378b22529a1b1c2f1 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 31 Oct 2024 13:30:58 -0700 Subject: [PATCH 16/16] up Signed-off-by: Rui Qiao --- python/ray/dag/dag_node_operation.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index 2a68c483dc4f..d54333e536af 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -507,12 +507,13 @@ def _visualize_execution_schedule( If this is different from orig_index, the node is highlighted in red color Node grouping: - The nodes belonging to the same actor are grouped in the same rectangular + The nodes belonging to the same actor are grouped in the same rectangle Edges: - black color: indicates shared memory channel (also annotated with "shm") - blue color: indicates NCCL channel (also annotated with "nccl") - dashed edge: indicates a control dependency between compute nodes + black color (without label): data dependency + black color (annotated with "shm"): shared memory channel + blue color (annotated with "nccl): NCCL channel + dashed edge: control dependency between compute operations Args: actor_to_execution_schedule: A dictionary that maps an actor handle to @@ -589,9 +590,10 @@ def _visualize_execution_schedule( 'The nodes belonging to the same actor are grouped in the same rectangular' # noqa "" 'Edges:' - 'black color: indicates shared memory channel (also annotated with "shm")' # noqa - 'blue color: indicates NCCL channel (also annotated with "nccl")' # noqa - 'dashed edge: indicates a control dependency between compute nodes' # noqa + 'black color (without label): data dependency' # noqa + 'black color (annotated with "shm"): shared memory channel' # noqa + 'blue color (annotated with "nccl): NCCL channel' # noqa + 'dashed edge: control dependency between compute operations' # noqa ">" )