diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 733118bf571d..66809cd1f437 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -24,7 +24,10 @@ from ray.dag.dag_operation_future import GPUFuture, DAGOperationFuture, ResolvedFuture from ray.experimental.channel.cached_channel import CachedChannel from ray.experimental.channel.communicator import Communicator -from ray.dag.constants import RAY_CGRAPH_VISUALIZE_SCHEDULE +from ray.dag.constants import ( + RAY_CGRAPH_ENABLE_NVTX_PROFILING, + RAY_CGRAPH_VISUALIZE_SCHEDULE, +) import ray from ray.exceptions import RayTaskError, RayChannelError from ray.experimental.compiled_dag_ref import ( @@ -153,6 +156,17 @@ def do_exec_tasks( for task in tasks: task.prepare(overlap_gpu_communication=overlap_gpu_communication) + if RAY_CGRAPH_ENABLE_NVTX_PROFILING: + try: + import nvtx + except ImportError: + raise ImportError( + "Please install nvtx to enable nsight profiling. " + "You can install it by running `pip install nvtx`." + ) + nvtx_profile = nvtx.Profile() + nvtx_profile.enable() + done = False while True: if done: @@ -163,6 +177,9 @@ def do_exec_tasks( ) if done: break + + if RAY_CGRAPH_ENABLE_NVTX_PROFILING: + nvtx_profile.disable() except Exception: logging.exception("Compiled DAG task exited with exception") raise @@ -1577,7 +1594,6 @@ def _get_or_compile( executable_tasks.sort(key=lambda task: task.bind_index) self.actor_to_executable_tasks[actor_handle] = executable_tasks - # Build an execution schedule for each actor from ray.dag.constants import RAY_CGRAPH_ENABLE_PROFILING if RAY_CGRAPH_ENABLE_PROFILING: @@ -1585,6 +1601,7 @@ def _get_or_compile( else: exec_task_func = do_exec_tasks + # Build an execution schedule for each actor self.actor_to_execution_schedule = self._build_execution_schedule() for actor_handle, executable_tasks in self.actor_to_executable_tasks.items(): self.worker_task_refs[actor_handle] = actor_handle.__ray_call__.options( diff --git a/python/ray/dag/constants.py b/python/ray/dag/constants.py index 25f481bd1c37..299acf3137c6 100644 --- a/python/ray/dag/constants.py +++ b/python/ray/dag/constants.py @@ -20,6 +20,13 @@ # Feature flag to turn on profiling. RAY_CGRAPH_ENABLE_PROFILING = os.environ.get("RAY_CGRAPH_ENABLE_PROFILING", "0") == "1" +# Feature flag to turn on NVTX (NVIDIA Tools Extension Library) profiling. +# With this flag, Compiled Graph uses nvtx to automatically annotate and profile +# function calls during each actor's execution loop. +RAY_CGRAPH_ENABLE_NVTX_PROFILING = ( + os.environ.get("RAY_CGRAPH_ENABLE_NVTX_PROFILING", "0") == "1" +) + # Feature flag to turn on visualization of the execution schedule. RAY_CGRAPH_VISUALIZE_SCHEDULE = ( os.environ.get("RAY_CGRAPH_VISUALIZE_SCHEDULE", "0") == "1"