[Ray] Integration compiled DAG off by default #2471

Merged
merged 9 commits on Feb 8, 2024
Changes from all commits
62 changes: 55 additions & 7 deletions vllm/engine/llm_engine.py
@@ -2,6 +2,7 @@
from collections import defaultdict
import os
import time
import pickle
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                    Union)

@@ -30,6 +31,11 @@
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5

# If the env var is set, the engine uses Ray's compiled DAG API,
# which reduces the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))


class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@@ -124,6 +130,10 @@ def __init__(
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

    def get_tokenizer_for_seq(self, sequence: Sequence):
        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

@@ -806,7 +816,8 @@ def step(self) -> List[RequestOutput]:
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

# Only the driver worker returns the sampling results.
output = all_outputs[0]
@@ -966,6 +977,7 @@ def _run_workers(
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
@@ -974,11 +986,16 @@
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]
        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *args, **kwargs)
                for worker in self.workers
            ]

        if driver_args is None:
            driver_args = args
@@ -991,6 +1008,37 @@

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare as parsed versions, not strings (e.g. "2.10" > "2.9").
        if (pkg_resources.parse_version(current_version)
                < pkg_resources.parse_version(required_version)):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import MultiOutputNode, InputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()
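For context, the control-plane pattern that _compiled_ray_dag() builds and _run_workers() consumes can be sketched in isolation. The snippet below is not part of this PR; it is a minimal, hedged illustration assuming Ray >= 2.9 with the experimental compiled-DAG API used above, and EchoWorker is a made-up stand-in for the RayWorkerVllm actor.

import pickle

import ray
from ray.dag import InputNode, MultiOutputNode


@ray.remote
class EchoWorker:
    def work(self, ignored):
        # Mirrors execute_model_compiled_dag_remote: pickle the result so the
        # driver can deserialize it from the output channel.
        return pickle.dumps({"result": 42})


ray.init()
workers = [EchoWorker.remote() for _ in range(2)]

# Build and compile the DAG once at startup, as _compiled_ray_dag() does.
with InputNode() as input_data:
    dag = MultiOutputNode([w.work.bind(input_data) for w in workers])
compiled_dag = dag.experimental_compile()

# Per step: execute with a dummy input, read every output channel, then call
# end_read() so the DAG can be reused, mirroring _run_workers() above.
channels = compiled_dag.execute(1)
try:
    outputs = [pickle.loads(chan.begin_read()) for chan in channels]
finally:
    for chan in channels:
        chan.end_read()
print(outputs)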
18 changes: 18 additions & 0 deletions vllm/engine/ray_utils.py
@@ -1,3 +1,5 @@
import pickle

from typing import Optional, List, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
@@ -18,6 +20,11 @@ def __init__(self, init_cached_hf_modules=False) -> None:
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None
            # The compiled DAG runs the main execution in a different thread,
            # which calls cuda.set_device. This flag indicates whether
            # set_device has already been called on that thread.
            self.compiled_dag_cuda_device_set = False

        def init_worker(self, worker_init_fn):
            self.worker = worker_init_fn()
@@ -40,6 +47,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
        def set_cuda_visible_devices(self, device_ids) -> None:
            set_cuda_visible_devices(device_ids)

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when compiled DAG is enabled."""
            import torch
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output

except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
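Finally, since the feature is off by default, here is a hedged usage sketch (not from this PR): only the VLLM_USE_RAY_COMPILED_DAG variable comes from the diff above; the model name, tensor_parallel_size, and sampling settings are illustrative assumptions.

import os

# Must be set before vllm is imported, because llm_engine.py reads the flag
# at module import time.
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm import LLM, SamplingParams

# The compiled-DAG path only matters when Ray workers drive the control plane,
# e.g. with tensor parallelism across multiple GPUs.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)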