From de958dfe2b73111d8691445c184bc7d181d5081c Mon Sep 17 00:00:00 2001
From: sang
Date: Wed, 17 Jan 2024 00:42:12 -0800
Subject: [PATCH 1/5] ip

---
 tests/models/test_models.py | 11 +++++----
 vllm/engine/llm_engine.py   | 46 +++++++++++++++++++++++++++++++------
 vllm/engine/ray_utils.py    |  9 ++++++++
 vllm/entrypoints/llm.py     |  2 +-
 4 files changed, 55 insertions(+), 13 deletions(-)

diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 40858a517b311..c98caf72ee371 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -5,11 +5,12 @@
 import pytest
 
 MODELS = [
-    "facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
-    "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
-    "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
-    "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
-    "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
+    "facebook/opt-125m",
+    # "meta-llama/Llama-2-7b-hf",
+    # "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
+    # "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
+    # "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
+    # "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
 ]
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index e30bf5db49283..8e65e78369957 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,6 +2,7 @@
 from collections import defaultdict
 import os
 import time
+import pickle
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional,
                     Tuple, Union)
 
@@ -122,6 +123,8 @@
         self.num_prompt_tokens: List[Tuple[float, int]] = []
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []
+        # Only used for compiled DAG.
+        self.forward_dag = None
 
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -730,7 +733,8 @@
                 "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                 "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                 "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-            })
+            },
+            use_ray_compiled_dag=True)
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
@@ -873,6 +877,7 @@
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
@@ -881,11 +886,16 @@
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        # Start the ray workers first.
-        ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
-            for worker in self.workers
-        ]
+        if use_ray_compiled_dag:
+            if self.forward_dag is None:
+                self.forward_dag = self._compiled_dag_init_dag()
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
 
         if driver_args is None:
             driver_args = args
@@ -898,6 +908,28 @@
 
         # Get the results of the ray workers.
         if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read()) for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
 
         return [driver_worker_output] + ray_worker_outputs
+
+
+    def _compiled_dag_init_dag(self):
+        from ray.dag import MultiOutputNode, InputNode
+        assert self.parallel_config.worker_use_ray
+
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()

diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index fb8854e068c87..a75af0586ee6e 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -3,6 +3,7 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import is_hip, set_cuda_visible_devices, get_ip
+import pickle
 
 logger = init_logger(__name__)
 
@@ -40,6 +41,14 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)
 
+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when compiled DAG is enabled."""
+            print("SANG-TODO execute_model_compiled_dag_remote")
+            output = self.worker.execute_model()
+            print("SANG-TODO execute_model_compiled_dag_remote finished")
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 0700298b03a3d..9c912bd97ccac 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -71,7 +71,7 @@
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = False,
-        tensor_parallel_size: int = 1,
+        tensor_parallel_size: int = 4,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,

From 2e65bbb5f761e1b6fd1b4077195182caac12744f Mon Sep 17 00:00:00 2001
From: sang
Date: Wed, 17 Jan 2024 14:53:25 -0800
Subject: [PATCH 2/5] done basic version

---
 vllm/engine/llm_engine.py | 5 ++---
 vllm/engine/ray_utils.py  | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8e65e78369957..5da675b97d9ae 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -124,7 +124,8 @@
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []
         # Only used for compiled DAG.
-        self.forward_dag = None
+        self.forward_dag = self._compiled_dag_init_dag()
+        # self.forward_dag = None
 
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -887,8 +888,6 @@
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
         if use_ray_compiled_dag:
-            if self.forward_dag is None:
-                self.forward_dag = self._compiled_dag_init_dag()
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index a75af0586ee6e..b3b5225f81d65 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -43,9 +43,9 @@ def set_cuda_visible_devices(self, device_ids) -> None:
 
         def execute_model_compiled_dag_remote(self, ignored):
             """Used only when compiled DAG is enabled."""
-            print("SANG-TODO execute_model_compiled_dag_remote")
+            import torch
+            torch.cuda.set_device(self.worker.device)
             output = self.worker.execute_model()
-            print("SANG-TODO execute_model_compiled_dag_remote finished")
             output = pickle.dumps(output)
             return output

From c0c7b618764f1588f48bea50cb56fccc64328f16 Mon Sep 17 00:00:00 2001
From: sang
Date: Wed, 24 Jan 2024 01:44:26 -0800
Subject: [PATCH 3/5] ready

---
 tests/models/test_models.py | 11 +++++------
 vllm/engine/llm_engine.py   | 21 +++++++++++++++------
 vllm/engine/ray_utils.py    | 10 +++++++++-
 vllm/entrypoints/llm.py     |  2 +-
 4 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index c98caf72ee371..40858a517b311 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -5,12 +5,11 @@
 import pytest
 
 MODELS = [
-    "facebook/opt-125m",
-    # "meta-llama/Llama-2-7b-hf",
-    # "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
-    # "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
-    # "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
-    # "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
+    "facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
+    "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
+    "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
+    "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
 ]
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index dce0f60dd2f54..20523ba86f34c 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -31,6 +31,11 @@
 
 _LOGGING_INTERVAL_SEC = 5
 
+# If the env var is set, it uses Ray's compiled DAG API,
+# which optimizes the control plane overhead.
+# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
+
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -123,9 +128,11 @@
         self.num_prompt_tokens: List[Tuple[float, int]] = []
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []
-        # Only used for compiled DAG.
-        self.forward_dag = self._compiled_dag_init_dag()
-        # self.forward_dag = None
+        # Only used when USE_RAY_COMPILED_DAG is enabled.
+        # Stores the DAG definition to run VLLM workers.
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_dag_init_dag()
 
     def _init_workers(self):
@@ -751,7 +758,7 @@
                 "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                 "blocks_to_copy": scheduler_outputs.blocks_to_copy,
             },
-            use_ray_compiled_dag=True)
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
@@ -926,7 +933,8 @@
         if use_ray_compiled_dag:
             try:
                 ray_worker_outputs = [
-                    pickle.loads(chan.begin_read()) for chan in output_channels
+                    pickle.loads(chan.begin_read())
+                    for chan in output_channels
                 ]
             finally:
                 # Has to call end_read in order to reuse the DAG.
@@ -937,11 +945,12 @@
 
         return [driver_worker_output] + ray_worker_outputs
 
-
     def _compiled_dag_init_dag(self):
         from ray.dag import MultiOutputNode, InputNode
         assert self.parallel_config.worker_use_ray
 
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
         with InputNode() as input_data:
             forward_dag = MultiOutputNode([
                 worker.execute_model_compiled_dag_remote.bind(input_data)
                 for worker in self.workers

diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index b3b5225f81d65..59fe4a179b5a2 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -19,6 +19,11 @@ def __init__(self, init_cached_hf_modules=False) -> None:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # The compiled DAG runs the main execution
+            # on a different thread, and that thread calls
+            # cuda.set_device. This flag records whether
+            # set_device has already been called on that thread.
+            self.compiled_dag_cuda_device_set = False
 
         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
@@ -44,7 +49,10 @@ def set_cuda_visible_devices(self, device_ids) -> None:
         def execute_model_compiled_dag_remote(self, ignored):
             """Used only when compiled DAG is enabled."""
             import torch
-            torch.cuda.set_device(self.worker.device)
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
             output = self.worker.execute_model()
             output = pickle.dumps(output)
             return output

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 18dbbc4507bcb..b819e233c06b2 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -71,7 +71,7 @@
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = False,
-        tensor_parallel_size: int = 4,
+        tensor_parallel_size: int = 1,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,

From e5a96446eae202ff9917b63c4087d8edf5e2273c Mon Sep 17 00:00:00 2001
From: sang
Date: Wed, 7 Feb 2024 07:31:23 -0800
Subject: [PATCH 4/5] Addressed code review.

---
 vllm/engine/llm_engine.py | 14 ++++++++++++--
 vllm/engine/ray_utils.py  |  3 ++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ef6783434b933..02e123401865f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -130,7 +130,7 @@
         # Stores the DAG definition to run VLLM workers.
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
-            self.forward_dag = self._compiled_dag_init_dag()
+            self.forward_dag = self._compiled_ray_dag()
 
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
@@ -976,6 +976,8 @@
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
         if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
@@ -1010,7 +1012,15 @@
 
         return [driver_worker_output] + ray_worker_outputs
 
-    def _compiled_dag_init_dag(self):
+    def _compiled_ray_dag(self):
+        import pkg_resources
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        if current_version < required_version:
+            raise ValueError(
+                f"Ray version {required_version} or greater is "
+                f"required, but found {current_version}")
+
         from ray.dag import MultiOutputNode, InputNode
         assert self.parallel_config.worker_use_ray

diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index 9a3d359c62406..2974dec8f81d1 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -1,9 +1,10 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import is_hip, set_cuda_visible_devices, get_ip
 
 logger = init_logger(__name__)

From 8ddf045c1b20a9eea45625f0f4fd25b604a375e6 Mon Sep 17 00:00:00 2001
From: sang
Date: Thu, 8 Feb 2024 06:47:35 -0800
Subject: [PATCH 5/5] lint

---
 vllm/engine/llm_engine.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8449898e77d35..03a2b1157652b 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1028,9 +1028,8 @@ def _compiled_ray_dag(self):
         import pkg_resources
         required_version = "2.9"
         current_version = pkg_resources.get_distribution("ray").version
         if current_version < required_version:
-            raise ValueError(
-                f"Ray version {required_version} or greater is "
-                f"required, but found {current_version}")
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
 
         from ray.dag import MultiOutputNode, InputNode
         assert self.parallel_config.worker_use_ray
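
Note on the mechanism the series relies on: instead of one
execute_method.remote() round trip per worker per step, the engine defines
the worker fan-out once as a Ray DAG, compiles it, and then pushes a dummy
input through while reading each worker's pickled output back over a
channel. The snippet below is a minimal, self-contained sketch of that
same pattern against the experimental Ray 2.9-era API these patches
target; EchoWorker and its execute method are illustrative stand-ins for
RayWorkerVllm.execute_model_compiled_dag_remote, not part of the diffs,
and the API has changed in later Ray releases.

    import pickle

    import ray
    from ray.dag import InputNode, MultiOutputNode


    @ray.remote
    class EchoWorker:
        """Illustrative stand-in for RayWorkerVllm."""

        def execute(self, ignored):
            # Like execute_model_compiled_dag_remote, the result is
            # pickled by hand and shipped back through a DAG channel.
            return pickle.dumps("model output")


    ray.init()
    workers = [EchoWorker.remote() for _ in range(2)]

    # Define the DAG once and compile it; the compiled object is then
    # reused for every call, like self.forward_dag in LLMEngine.
    with InputNode() as input_data:
        dag = MultiOutputNode(
            [w.execute.bind(input_data) for w in workers])
    compiled_dag = dag.experimental_compile()

    # The compiled DAG requires an input, hence the dummy 1,
    # mirroring forward_dag.execute(1) in _run_workers.
    output_channels = compiled_dag.execute(1)
    try:
        outputs = [pickle.loads(c.begin_read()) for c in output_channels]
    finally:
        # end_read() has to be called before the DAG can run again.
        for c in output_channels:
            c.end_read()
    print(outputs)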
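To exercise the feature end to end, the opt-in added in PATCH 3/5 must be
set before vllm is imported, since USE_RAY_COMPILED_DAG is read at module
import time, and the engine has to be driving Ray workers:
_compiled_ray_dag asserts parallel_config.worker_use_ray, which holds when
tensor_parallel_size > 1. A hypothetical invocation, assuming a host with
at least two GPUs:

    import os

    # Must be set before importing vllm; the flag is read at import time.
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
    print(llm.generate("Hello, my name is"))

One caveat in the PATCH 4/5 version guard: current_version <
required_version compares strings lexicographically, so a release such as
"2.10" sorts before "2.9" and would be rejected. Comparing parsed versions
instead, for example pkg_resources.parse_version(current_version) <
pkg_resources.parse_version(required_version), avoids that edge case.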