From 04f262e75f4c8670de47fb1ca303823102709dad Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 04:09:02 -0400 Subject: [PATCH 01/80] soft prompt support --- tests/prompt_adapter/test_bloom.py | 37 +++ tests/worker/test_model_runner.py | 3 +- vllm/adapter_commons/__init__.py | 0 vllm/adapter_commons/models.py | 86 ++++++ vllm/adapter_commons/request.py | 19 ++ vllm/adapter_commons/worker_manager.py | 62 ++++ vllm/config.py | 13 + vllm/core/scheduler.py | 19 ++ vllm/engine/arg_utils.py | 19 +- vllm/engine/llm_engine.py | 41 ++- vllm/entrypoints/llm.py | 9 +- vllm/executor/cpu_executor.py | 12 + vllm/executor/executor_base.py | 18 +- vllm/executor/gpu_executor.py | 16 + vllm/lora/models.py | 139 ++++----- vllm/lora/request.py | 27 +- vllm/lora/worker_manager.py | 179 ++++-------- vllm/model_executor/models/baichuan.py | 2 + vllm/model_executor/models/bloom.py | 2 + vllm/model_executor/models/gpt_bigcode.py | 2 + vllm/model_executor/models/llama.py | 2 + vllm/model_executor/models/mixtral.py | 2 + vllm/prompt_adapter/__init__.py | 0 vllm/prompt_adapter/layers.py | 21 ++ vllm/prompt_adapter/models.py | 341 ++++++++++++++++++++++ vllm/prompt_adapter/request.py | 26 ++ vllm/prompt_adapter/worker_manager.py | 186 ++++++++++++ vllm/sequence.py | 35 ++- vllm/worker/cpu_model_runner.py | 8 +- vllm/worker/cpu_worker.py | 5 +- vllm/worker/model_runner.py | 120 ++++++-- vllm/worker/worker.py | 17 +- 32 files changed, 1227 insertions(+), 241 deletions(-) create mode 100644 tests/prompt_adapter/test_bloom.py create mode 100644 vllm/adapter_commons/__init__.py create mode 100644 vllm/adapter_commons/models.py create mode 100644 vllm/adapter_commons/request.py create mode 100644 vllm/adapter_commons/worker_manager.py create mode 100644 vllm/prompt_adapter/__init__.py create mode 100644 vllm/prompt_adapter/layers.py create mode 100644 vllm/prompt_adapter/models.py create mode 100644 vllm/prompt_adapter/request.py create mode 100644 vllm/prompt_adapter/worker_manager.py diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py new file mode 100644 index 0000000000000..5f09ee1304f76 --- /dev/null +++ b/tests/prompt_adapter/test_bloom.py @@ -0,0 +1,37 @@ +import vllm +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' + + +def do_sample(llm, pa_name: str, pa_id: int): + + prompts = [ + "Tweet text : @nationalgridus I have no water and the bill is \ + current and paid. Can you do something about this? Label : ", + "Tweet text : @nationalgridus Looks good thanks! Label : " + ] + sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=3) + + outputs = llm.generate(prompts, + sampling_params, + prompt_adapter_request=PromptAdapterRequest( + pa_name, pa_id, PA_PATH, 8) if pa_id else None) + + # Print the outputs. 
+ generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_twitter_prompt_adapter(): + llm = vllm.LLM(MODEL_PATH) + + expected_output = ['complaint', 'no complaint'] + + assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 514a57e17ebf4..87391e09d34cd 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -21,6 +21,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: cache_config=engine_config.cache_config, load_config=engine_config.load_config, lora_config=engine_config.lora_config, + prompt_adapter_config=engine_config.prompt_adapter_config, is_driver_worker=True, ) return model_runner @@ -353,7 +354,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_metadata_list.append(seq_group_metadata) (input_tokens, input_positions, attn_metadata, _, _, _, - _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + _, _, _) = model_runner.prepare_input_tensors(seq_group_metadata_list) prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 0000000000000..630aeb9ba0101 --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,86 @@ +import os +import json +from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, TypeVar +import torch +from torch import nn +from vllm.utils import LRUCache +from vllm.logger import init_logger +logger = init_logger(__name__) + +class AdapterModel: + def __init__(self, model_id=None): + self.id = model_id + + @classmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + +T = TypeVar('T') +class AdapterLRUCache(LRUCache[T]): + def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], None]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: Hashable, value: T): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + +class AdapterModelManager: + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. 
+ self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + + def __len__(self) -> int: + return len(self._registered_adapters) + + def _deactivate_adapter(self, adapter_id: int): + raise NotImplementedError("Subclasses must implement this method.") + + def deactivate_adapter(self, adapter_id: int) -> bool: + if adapter_id in self._active_adapters: + self._deactivate_adapter(adapter_id) + self._active_adapters.pop(adapter_id) + return True + return False + + def activate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError("Subclasses must implement this method.") + + def _add_adapter(self, adapter: Any): + raise NotImplementedError("Subclasses must implement this method.") + + def add_adapter(self, adapter: Any) -> bool: + if adapter.id not in self._registered_adapters: + if len(self._registered_adapters) >= self.capacity: + raise RuntimeError("No free "+self.adapter_type+" slots.") + self._add_adapter(adapter) + return True + return False + + def set_adapter_mapping(self, mapping: Any) -> None: + if self._last_mapping != mapping: + self._set_adapter_mapping(mapping) + self._last_mapping = mapping + + def remove_adapter(self, adapter_id: int) -> bool: + self.deactivate_adapter(adapter_id) + return bool(self._registered_adapters.pop(adapter_id, None)) + + def list_adapters(self) -> Dict[int, Any]: + return dict(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return self._registered_adapters.get(adapter_id, None) \ No newline at end of file diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 0000000000000..26d59babdb987 --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +@dataclass +class AdapterRequest: + """ + Base class for adapter requests. + """ + + def __post_init__(self): + if self.adapter_id < 1: + raise ValueError( + f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) \ No newline at end of file diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000000000..ee113f1e2080b --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,62 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Dict +import torch + +class AbstractWorkerManager(ABC): + def __init__(self, device: torch.device): + self.device = device + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_manager(self, model: torch.nn.Module) -> Any: + ... + + def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: + self._apply_adapters(requests) + self._model_manager.set_adapter_mapping(mapping) + + @abstractmethod + def add_dummy_adapter(self, request: Any) -> bool: + ... 
+ + def add_adapter(self, adapter_request: Any) -> bool: + if adapter_request.adapter_id in self.list_adapters(): + return False + loaded_adapter = self._load_adapter(adapter_request) + loaded = self._model_manager.add_adapter(loaded_adapter) + self._model_manager.activate_adapter(loaded_adapter.id) + return loaded + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + models_that_exist = self.list_adapters() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > self._model_manager.adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + "than the number of GPU model slots " + f"({self._model_manager.adapter_slots}).") + + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + + for adapter_id in models_to_remove: + self.remove_adapter(adapter_id) + + for adapter_id in models_to_add: + self.add_adapter(models_map[adapter_id]) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._model_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._model_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return set(self._model_manager.list_adapters()) \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 2513d43ce8e6b..3942b11c35c7f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1094,6 +1094,18 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): "LoRA is enabled.") +@dataclass +class PromptAdapterConfig: + max_prompt_adapters: int + max_cpu_prompt_adapters: Optional[int] = None + + def __post_init__(self): + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + @dataclass class VisionLanguageConfig: """Configs the input data format and how models should run for @@ -1371,6 +1383,7 @@ class EngineConfig: vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] + prompt_adapter_config: Optional[PromptAdapterConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 48c34625c08ae..4c5b244cb1e42 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -139,6 +140,10 @@ def __post_init__(self): if self.num_loras > 0: self._sort_by_lora_ids() + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + if self.num_prompt_adapters > 0: + self._sort_by_prompt_adapter_ids() + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. 
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in @@ -149,6 +154,11 @@ def _sort_by_lora_ids(self): self.scheduled_seq_groups, key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) + def _sort_by_prompt_adapter_ids(self): + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.seq_group.prompt_adapter_id, g.seq_group.request_id)) + @property def lora_requests(self) -> Set[LoRARequest]: return { @@ -157,6 +167,14 @@ def lora_requests(self) -> Set[LoRARequest]: if g.seq_group.lora_request is not None } + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + @dataclass class SchedulerRunningOutputs: @@ -1006,6 +1024,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + prompt_adapter_request=seq_group.prompt_adapter_request, ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 227de5475b949..4f8006e914a3a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig, VisionLanguageConfig) + TokenizerPoolConfig, VisionLanguageConfig, PromptAdapterConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import str_to_int_tuple @@ -64,8 +64,10 @@ class EngineArgs: tokenizer_pool_type: str = "ray" tokenizer_pool_extra_config: Optional[dict] = None enable_lora: bool = False + enable_prompt_adapter: bool = False max_loras: int = 1 max_lora_rank: int = 16 + max_prompt_adapters: int = 1 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None @@ -501,6 +503,13 @@ def add_cli_args( 'Enabling this will use the fully sharded layers. 
' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) + # parser.add_argument('--enable-prompt-adapter', + # action='store_true', + # help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -721,7 +730,10 @@ def create_engine_config(self, ) -> EngineConfig: download_dir=self.download_dir, model_loader_extra_config=self.model_loader_extra_config, ) - + + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters) #if self.enable_prompt_adapter else None + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): @@ -772,7 +784,8 @@ def create_engine_config(self, ) -> EngineConfig: vision_language_config=vision_language_config, speculative_config=speculative_config, load_config=load_config, - decoding_config=decoding_config) + decoding_config=decoding_config, + prompt_adapter_config=prompt_adapter_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ea754758492f3..fae815f4e3a70 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + VisionLanguageConfig, PromptAdapterConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -24,6 +24,7 @@ from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -154,6 +155,7 @@ def __init__( vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -207,6 +209,7 @@ def __init__( self.speculative_config = speculative_config self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: @@ -230,6 +233,7 @@ def __init__( vision_language_config=vision_language_config, speculative_config=speculative_config, load_config=load_config, + prompt_adapter_config=prompt_adapter_config, ) if not self.model_config.embedding_mode: @@ -336,7 +340,6 @@ def from_engine_args( engine_config = engine_args.create_engine_config() distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor @@ -358,7 +361,6 @@ def from_engine_args( else: from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor - # Create the LLM engine. 
engine = cls( **engine_config.to_dict(), @@ -436,6 +438,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: # Create the sequences. block_size = self.cache_config.block_size @@ -443,7 +446,7 @@ def _add_processed_request( eos_token_id = self._get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) + lora_request, prompt_adapter_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -453,6 +456,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request ) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( @@ -461,6 +465,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request ) else: raise ValueError( @@ -474,6 +479,7 @@ def process_model_inputs( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -488,6 +494,10 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + + prompt_token_ids + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -499,6 +509,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: """Add a request to the engine's request pool. 
@@ -549,7 +560,8 @@ def add_request( processed_inputs = self.process_model_inputs(request_id=request_id, inputs=inputs, - lora_request=lora_request) + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -557,6 +569,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request ) def _create_sequence_group_with_sampling( @@ -566,6 +579,7 @@ def _create_sequence_group_with_sampling( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -591,7 +605,8 @@ def _create_sequence_group_with_sampling( seqs=[seq], arrival_time=arrival_time, sampling_params=sampling_params, - lora_request=lora_request) + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) return seq_group @@ -602,6 +617,7 @@ def _create_sequence_group_with_pooling( pooling_params: PoolingParams, arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -611,7 +627,8 @@ def _create_sequence_group_with_pooling( seqs=[seq], arrival_time=arrival_time, lora_request=lora_request, - pooling_params=pooling_params) + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -976,5 +993,15 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> List[int]: return self.model_executor.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + def check_health(self) -> None: self.model_executor.check_health() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e923493160ed..ac20d7b13c3b5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -11,6 +11,7 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -250,6 +251,7 @@ def generate( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. 
@@ -299,6 +301,7 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -499,6 +502,7 @@ def _validate_and_add_requests( params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -521,6 +525,7 @@ def _validate_and_add_requests( params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, + prompt_adapter_request=prompt_adapter_request ) def _add_request( @@ -528,12 +533,14 @@ def _add_request( inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, inputs, params, - lora_request=lora_request) + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) def _run_engine( self, *, use_tqdm: bool diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index a2212459f034e..4eb8fb89c399f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -7,6 +7,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -48,6 +49,7 @@ def _init_worker(self): lora_config=self.lora_config, vision_language_config=self.vision_language_config, kv_cache_dtype=self.cache_config.cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=True, ) self.driver_worker.init_device() @@ -87,6 +89,16 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. 
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 4d01939c2e38b..145c451bd2f8f 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -3,8 +3,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, PromptAdapterConfig) from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -27,6 +28,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig] ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -37,6 +39,7 @@ def __init__( self.device_config = device_config self.vision_language_config = vision_language_config self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config self._init_executor() @@ -90,6 +93,19 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: raise NotImplementedError + @abstractmethod + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError + @abstractmethod def check_health(self) -> None: """Checks if the executor is healthy. If not, it should raise an diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3ad201f4757ec..c4bfaaa12bb5c 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,6 +3,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -45,6 +46,7 @@ def _get_worker_kwargs( lora_config=self.lora_config, vision_language_config=self.vision_language_config, speculative_config=self.speculative_config, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=rank == 0, ) @@ -102,6 +104,20 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. 
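[Note on the refactor below, not part of the patch] The vllm/lora changes that follow route the LoRA-specific entry points through the new adapter_commons base classes, keeping the old names (add_lora, remove_lora, list_loras, set_lora_mapping, ...) as properties that return the generic adapter methods so existing call sites keep working. A toy sketch of that aliasing pattern, with illustrative class names that are not taken from vLLM:

# Toy illustration of the property-alias pattern used in the LoRA refactor:
# the generic method lives on the shared base class, and the old
# domain-specific name is kept as a property returning the bound method.
class AdapterModelManagerBase:
    def __init__(self):
        self._registered_adapters = {}

    def add_adapter(self, adapter) -> bool:
        if adapter.id in self._registered_adapters:
            return False
        self._registered_adapters[adapter.id] = adapter
        return True


class LoRAStyleManager(AdapterModelManagerBase):
    @property
    def add_lora(self):
        # Old entry point resolves to the generic implementation.
        return self.add_adapter


class Adapter:
    def __init__(self, adapter_id):
        self.id = adapter_id


mgr = LoRAStyleManager()
assert mgr.add_lora(Adapter(1))   # calls add_adapter under the hood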
diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3e82856866d85..3b09d75f3daf5 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -19,6 +19,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.utils import LRUCache, is_pin_memory_available +from vllm.adapter_commons.models import AdapterModel, AdapterLRUCache, AdapterModelManager logger = init_logger(__name__) @@ -152,7 +153,7 @@ def get_lora_id(): return _GLOBAL_LORA_ID -class LoRAModel: +class LoRAModel(AdapterModel): """A LoRA fine-tuned model.""" def __init__( @@ -358,7 +359,7 @@ def from_local_checkpoint( ) -class LoRAModelManager: +class LoRAModelManager(AdapterModelManager): """A manager that manages multiple LoRA-fine-tuned models.""" def __init__( @@ -410,8 +411,7 @@ def __init__( # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices self.indices_len: List[Optional[int]] = [None] * 4 - - self.model: nn.Module = model + super().__init__(model) if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) @@ -423,12 +423,13 @@ def __init__( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} - self._registered_loras: Dict[int, LoRAModel] = {} + # self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. - self._active_loras: Dict[int, None] = {} + # self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self + self.adapter_type = 'LoRa' @property def capacity(self) -> int: @@ -438,15 +439,16 @@ def capacity(self) -> int: def lora_slots(self) -> int: return self.lora_config.max_loras - def __len__(self) -> int: - return len(self._registered_loras) + @property + def adapter_slots(self) -> int: + return self.lora_slots def activate_lora( self, lora_id: int, ) -> bool: """Move LoRA into a GPU buffer to be used in the forward pass.""" - if lora_id in self._active_loras: + if lora_id in self._active_adapters: return False first_free_slot = next( ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) @@ -454,8 +456,8 @@ def activate_lora( if first_free_slot is None: raise ValueError("No free lora slots") index, _ = first_free_slot - self._active_loras[lora_id] = None - lora_model = self._registered_loras[lora_id] + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] logger.debug("Activating LoRA. 
int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id @@ -469,6 +471,10 @@ def activate_lora( module.reset_lora(index) return True + @property + def activate_adapter(self): + return self.activate_lora + def _deactivate_lora(self, lora_id: int): try: index = self.lora_index_to_id.index(lora_id) @@ -476,14 +482,14 @@ def _deactivate_lora(self, lora_id: int): except ValueError: pass - def deactivate_lora(self, lora_id: int) -> bool: - """Remove a LoRA from a GPU buffer.""" - if lora_id in self._active_loras: - self._deactivate_lora(lora_id) - self._active_loras.pop(lora_id) - return True - return False + @property + def _deactivate_adapter(self): + return self._deactivate_lora + @property + def deactivate_lora(self): + return self.deactivate_adapter + def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: return @@ -501,29 +507,23 @@ def _set_long_lora_context(self, lora: LoRAModel): def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora + self._registered_adapters[lora.id] = lora self._set_long_lora_context(lora) - def add_lora(self, lora: LoRAModel) -> bool: - """Add a LoRAModel to the manager CPU cache.""" + @property + def _add_adapter(self): + return self._add_lora + + def add_lora(self, lora: LoRAModel): logger.debug( "Adding lora. Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - if len(self._registered_loras) >= self.capacity: - raise RuntimeError("No free LoRA slots.") - self._add_lora(lora) - return True - return False + return self.add_adapter(lora) - def remove_lora(self, lora_id: int) -> bool: - """Remove a LoRAModel from the manager CPU cache.""" - # TODO: should we check active lora? 
- self.deactivate_lora(lora_id) - if self.long_lora_context: - self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) - return bool(self._registered_loras.pop(lora_id, None)) + @property + def remove_lora(self): + return self.remove_adapter # TODO see if this can be vectorized def _set_lora_mapping(self, mapping: LoRAMapping) -> None: @@ -548,23 +548,31 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: - if self._last_mapping != lora_mapping: - self._set_lora_mapping(lora_mapping) - self._last_mapping = lora_mapping + @property + def set_lora_mapping(self): + return self.set_adapter_mapping - def list_loras(self) -> Dict[int, LoRAModel]: - """List all registered LoRAModels.""" - return dict(self._registered_loras) + @property + def _set_adapter_mapping(self): + return self._set_lora_mapping - def get_lora(self, lora_id: int) -> Optional[LoRAModel]: - return self._registered_loras.get(lora_id, None) + @property + def list_loras(self): + return self.list_adapters + @property + def get_lora(self): + return self.get_adapter + def remove_all_loras(self): """Remove all LoRAModels from the manager.""" - self._registered_loras.clear() + self._registered_adapters.clear() self.lora_index_to_id = [None] * self.lora_slots - self._active_loras.clear() + self._active_adapters.clear() + + @property + def remove_all_adapters(self): + return self.remove_all_loras def _create_lora_modules(self): for module_name, module in self.model.named_modules( @@ -709,17 +717,10 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: replacement_loras) -class LoRALRUCache(LRUCache[LoRAModel]): - - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], - bool]): - super().__init__(capacity) - self.deactivate_lora_fn = deactivate_lora_fn +class LoRALRUCache(AdapterLRUCache[LoRAModel]): - def _on_remove(self, key: int, value: LoRAModel): - logger.debug("Removing LoRA. int id: %d", key) - self.deactivate_lora_fn(key) - return super()._on_remove(key, value) + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], None]): + super().__init__(capacity, deactivate_lora_fn) class LRUCacheLoRAModelManager(LoRAModelManager): @@ -735,45 +736,45 @@ def __init__( ): super().__init__(model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config) - self._registered_loras: LoRALRUCache = LoRALRUCache( + self._registered_adapters: LoRALRUCache = LoRALRUCache( self.capacity, self.deactivate_lora) - self._active_loras: LoRALRUCache = LoRALRUCache( + self._active_adapters: LoRALRUCache = LoRALRUCache( self.lora_slots, self._deactivate_lora) - def list_loras(self) -> Dict[int, LoRAModel]: + def list_adapters(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" - return dict(self._registered_loras.cache) + return dict(self._registered_adapters.cache) - def add_lora(self, lora: LoRAModel) -> bool: + def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" logger.debug( "Adding lora. 
Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: + if lora.id not in self._registered_adapters: self._add_lora(lora) was_added = True else: # We always touch to update the LRU cache order - self._registered_loras.touch(lora.id) + self._registered_adapters.touch(lora.id) was_added = False return was_added - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: - if lora_id not in self._active_loras and len( - self._active_loras) >= self.lora_slots: - self._active_loras.remove_oldest() + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() result = super().activate_lora(lora_id) # We always touch to update the LRU cache order - self._active_loras.touch(lora_id) + self._active_adapters.touch(lora_id) return result def remove_oldest_lora(self) -> bool: - if len(self._registered_loras) > 0: - self._registered_loras.remove_oldest() + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() return True return False diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 662774ffe09ae..eb301311a7c10 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,13 +1,13 @@ from dataclasses import dataclass from typing import Optional - +from vllm.adapter_commons.request import AdapterRequest @dataclass -class LoRARequest: +class LoRARequest(AdapterRequest): """ Request for a LoRA adapter. - Note that this class should be be used internally. For online + Note that this class should be used internally. For online serving, it is recommended to not allow users to use this class but instead provide another layer of abstraction to prevent users from accessing unauthorized LoRA adapters. @@ -20,15 +20,16 @@ class LoRARequest: lora_int_id: int lora_local_path: str long_lora_max_len: Optional[int] = None + __hash__ = AdapterRequest.__hash__ - def __post_init__(self): - if self.lora_int_id < 1: - raise ValueError( - f"lora_int_id must be > 0, got {self.lora_int_id}") - - def __eq__(self, value: object) -> bool: - return isinstance( - value, LoRARequest) and self.lora_int_id == value.lora_int_id - - def __hash__(self) -> int: + @property + def adapter_id(self): return self.lora_int_id + + @property + def name(self): + return self.lora_name + + @property + def local_path(self): + return self.lora_local_path \ No newline at end of file diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 4657757bd484b..d247c202d472b 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -10,83 +10,17 @@ from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest +from vllm.adapter_commons.worker_manager import AbstractWorkerManager logger = init_logger(__name__) - -class AbstractWorkerLoRAManager(ABC): - """Abstract class for managing LoRA models on the worker side.""" - - def __init__(self, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - device: torch.device, - max_position_embeddings: Optional[int] = None): - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens - self.max_position_embeddings = max_position_embeddings - self.vocab_size = vocab_size - self.device = device - self.lora_config = lora_config - - # If False, do not cache. If None, cache is empty. 
- self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False - - @contextmanager - def dummy_lora_cache(self): - """Use this context manager to reuse the dummy lora model - to avoid creating it repeatedly.""" - self._cached_dummy_lora = None - yield - self._cached_dummy_lora = False - - @property - @abstractmethod - def is_enabled(self) -> bool: - ... - - @abstractmethod - def create_lora_manager( - self, - model: torch.nn.Module, - ) -> Any: - ... - - @abstractmethod - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - ... - - @abstractmethod - def add_lora(self, lora_request: LoRARequest) -> bool: - ... - - @abstractmethod - def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - ... - - @abstractmethod - def remove_lora(self, lora_id: int) -> bool: - ... - - @abstractmethod - def remove_all_loras(self): - ... - - @abstractmethod - def list_loras(self) -> Set[int]: - ... - - -class WorkerLoRAManager(AbstractWorkerLoRAManager): +class WorkerLoRAManager(AbstractWorkerManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.""" - _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + _manager_cls: Type[LoRAModelManager] = LoRAModelManager def __init__( self, @@ -103,17 +37,24 @@ def __init__( self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) # Lazily initialized by create_lora_manager. 
self._lora_manager: LoRAModelManager - super().__init__( - max_num_seqs, - max_num_batched_tokens, - vocab_size, - lora_config, - device, - max_position_embeddings=max_position_embeddings, - ) - + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + @property def is_enabled(self) -> bool: return True @@ -128,37 +69,14 @@ def create_lora_manager( max_num_batched_tokens=self.max_num_batched_tokens, vocab_size=self.vocab_size, lora_config=self.lora_config, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, ) self._lora_manager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - self._apply_loras(lora_requests) - self._lora_manager.set_lora_mapping(lora_mapping) - - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: - loras_that_exist = self.list_loras() - loras_map = { - lora_request.lora_int_id: lora_request - for lora_request in lora_requests if lora_request - } - if len(loras_map) > self._lora_manager.lora_slots: - raise RuntimeError( - f"Number of requested LoRAs ({len(loras_map)}) is greater " - "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") - - new_loras = set(loras_map) - loras_to_add = new_loras - loras_that_exist - loras_to_remove = loras_that_exist - new_loras - - for lora_id in loras_to_remove: - self.remove_lora(lora_id) - - for lora_id in loras_to_add: - self.add_lora(loras_map[lora_id]) + @property + def set_active_loras(self): + return self.set_active_adapters def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: try: @@ -210,22 +128,41 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: self._cached_dummy_lora = dummy_lora return self._lora_manager.add_lora(dummy_lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id in self.list_loras(): - return False - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) - self._lora_manager.activate_lora(lora.id) - return loaded + @property + def add_dummy_adapter(self): + return self.add_dummy_lora - def remove_lora(self, lora_id: int) -> bool: - return self._lora_manager.remove_lora(lora_id) + @property + def create_manager(self): + return self.create_lora_manager + @property + def _load_adapter(self): + return self._load_lora + + @property + def _model_manager(self): + return self._lora_manager + + @property + def add_lora(self): + return self.add_adapter + + @property + def remove_lora(self): + return self.remove_adapter + + @property def remove_all_loras(self): - self._lora_manager.remove_all_loras() + return self.remove_all_adapters + + @property + def list_loras(self): + return self.list_adapters - def list_loras(self) -> Set[int]: - return set(self._lora_manager.list_loras()) + @property + def _apply_loras(self): + return self._apply_adapters class LRUCacheWorkerLoRAManager(WorkerLoRAManager): @@ -235,7 +172,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _lora_manager_cls: Type[ + _manager_cls: Type[ LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( @@ -244,7 +181,7 @@ def create_lora_manager( ) -> Any: lora_manager = create_lora_manager( model, - 
lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, max_num_seqs=self.max_num_seqs, vocab_size=self.vocab_size, lora_config=self.lora_config, @@ -253,7 +190,7 @@ def create_lora_manager( self._lora_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request @@ -266,7 +203,7 @@ def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: for lora in loras_map.values(): self.add_lora(lora) - def add_lora(self, lora_request: LoRARequest) -> bool: + def add_adapter(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id not in self.list_loras(): # Remove before we load the new lora to save memory if len(self._lora_manager) + 1 > self._lora_manager.capacity: diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index babb92e7cdcef..08599941507cc 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from vllm.prompt_adapter.layers import apply_prompt_adapter def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -278,6 +279,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) + hidden_states = apply_prompt_adapter(self, hidden_states, positions) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a29aee4cffb7d..29d935d723af4 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from vllm.prompt_adapter.layers import apply_prompt_adapter def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -251,6 +252,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) + hidden_states = apply_prompt_adapter(self, hidden_states, position_ids) hidden_states = self.word_embeddings_layernorm(hidden_states) for i in range(len(self.h)): layer = self.h[i] diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 69b75763e9a3d..b50845f236151 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from vllm.prompt_adapter.layers import apply_prompt_adapter class GPTBigCodeAttention(nn.Module): @@ -219,6 +220,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) + inputs_embeds = apply_prompt_adapter(self, inputs_embeds, position_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 
d83ee9a201c0b..9618b74f282f8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,6 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once +from vllm.prompt_adapter.layers import apply_prompt_adapter class LlamaMLP(nn.Module): @@ -282,6 +283,7 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) + hidden_states = apply_prompt_adapter(self, hidden_states, positions) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 3faf54d292b99..b7c6f3db57405 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -53,6 +53,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once +from vllm.prompt_adapter.layers import apply_prompt_adapter class MixtralMoE(nn.Module): @@ -462,6 +463,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) + hidden_states = apply_prompt_adapter(self, hidden_states, positions) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/prompt_adapter/__init__.py b/vllm/prompt_adapter/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py new file mode 100644 index 0000000000000..799a2e38feecc --- /dev/null +++ b/vllm/prompt_adapter/layers.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import Tuple +import torch + +@dataclass +class PromptAdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] 
+ + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + +def apply_prompt_adapter(instance, hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + if hasattr(instance, 'prefix_encoder'): + soft_prompt = instance.prefix_encoder.prompt_embedding + indices = (positions < soft_prompt.shape[0]) + hidden_states[indices] = soft_prompt[positions[indices]] + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py new file mode 100644 index 0000000000000..84eb3e8ee03e1 --- /dev/null +++ b/vllm/prompt_adapter/models.py @@ -0,0 +1,341 @@ +import logging +from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, Tuple +from peft.utils import load_peft_weights +import torch +import math +from torch import nn +from vllm.adapter_commons.models import AdapterModel, AdapterLRUCache, AdapterModelManager +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.config import PromptAdapterConfig +from vllm.utils import LRUCache + +logger = logging.getLogger(__name__) + +_GLOBAL_PROMPT_ADAPTER_ID = 0 + + +def get_prompt_adapter_id(): + global _GLOBAL_PROMPT_ADAPTER_ID + _GLOBAL_PROMPT_ADAPTER_ID += 1 + return _GLOBAL_PROMPT_ADAPTER_ID + +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], + max_prompt_adapters: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter ids to PromptAdapter indices. + max_prompt_adapters: Maximum number of PromptAdapters. + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + PromptAdapter indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to PromptAdapter indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to PromptAdapter indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_promt_adapters. + indices_len: List of lengths of the above tensors. + Used to index into each tensor. It contains length for + (base_indices, sampler_indices, sampler_indices_padded). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + prompt_adapter_indices = index_mapping_indices.copy() + prompt_mapping: List[int] = [ + prompt_adapter_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + prompt_adapter_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. 
optimize + prompt_adapter_idx = (prompt_adapter_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + prompt_adapter_indices[i] = prompt_adapter_idx + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, prompt_adapter_indices + ] + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_prompt_adapters - 1 + sampler_indices_padded = ( + torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + + (sampler_indices_padded * len(sampler_indices_padded))) + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1] + ] + return (base_indices, sampler_indices, sampler_indices_padded, + indices_len) + + +class PromptAdapterModel(AdapterModel): + def __init__(self, + prompt_adapter_id=None, + num_virtual_tokens=None, + prompt_embedding=None) -> None: + self.id = prompt_adapter_id + self.kv_cache = None + self.prompt_embedding = prompt_embedding + self.num_virtual_tokens = num_virtual_tokens + + @classmethod + def from_local_checkpoint(cls, + adapter_model_and_path, + prompt_adapter_id, + torch_device='cuda') -> "PromptAdapterModel": + adapters_weights = load_peft_weights(adapter_model_and_path, + torch_device) + prompt_embedding = adapters_weights["prompt_embeddings"].half() + num_virtual_tokens = prompt_embedding.shape[0] + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) + + +class PromptAdapterModelManager(AdapterModelManager): + """A manager that manages multiple Prompt Adapter models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + """Create a PromptAdapterModel and adapter for a given model. + + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + # Dict instead of a Set for compatibility with LRUCache. 
+ self.prompt_adapter_index_to_id: List[Optional[int]] =\ + [None] * self.prompt_adapter_slots + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.prompt_adapter_config = prompt_adapter_config + self._last_mapping = None + self.model.prompt_adapter_manager = self + self.adapter_type = 'PromptAdapter' + + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.indices_len: List[Optional[int]] = [None] * 3 + self._last_mapping: Optional[PromptAdapterMapping] = None + + @property + def prompt_adapter_slots(self) -> int: + return self.prompt_adapter_config.max_prompt_adapters + + @property + def adapter_slots(self) -> int: + return self.prompt_adapter_slots + + @property + def capacity(self) -> int: + return self.prompt_adapter_config.max_cpu_prompt_adapters + + def activate_prompt_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + """Move PromptAdapter into a GPU buffer + to be used in the forward pass.""" + if prompt_adapter_id in self._active_adapters: + return False + first_free_slot = next( + ((i, prompt_adapter_id) for i, prompt_adapter_id in \ + enumerate(self.prompt_adapter_index_to_id) + if prompt_adapter_id is None), None) + if first_free_slot is None: + raise ValueError("No free prompt_adapter slots") + index, _ = first_free_slot + self._active_adapters[prompt_adapter_id] = None + prompt_adapter_model = \ + self._registered_adapters[prompt_adapter_id] + logger.debug("Activating prompt_adapter. int id: %d, slot index: %d", + prompt_adapter_model.id, index) + self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id + for module_name, module in self.model.named_modules(): + if 'Model' in (module.__class__.__name__): + module.prefix_encoder = prompt_adapter_model + break + return True + + @property + def activate_adapter(self): + return self.activate_prompt_adapter + + def _deactivate_prompt_adapter(self, prompt_adapter_id: int): + try: + index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) + self.prompt_adapter_index_to_id[index] = None + for module_name, module in self.model.named_modules(): + if 'Model' in (module.__class__.__name__): + del module.prefix_encoder + except ValueError: + pass + + @property + def _deactivate_adapter(self): + return self._deactivate_prompt_adapter + + def deactivate_prompt_adapter(self, prompt_adapter_id: int) -> bool: + """Remove a prompt_adapter from a GPU buffer.""" + if prompt_adapter_id in self._active_adapters: + self._deactivate_prompt_adapter(prompt_adapter_id) + self._active_adapters.pop(prompt_adapter_id) + return True + return False + + @property + def deactivate_prompt_adapter(self): + return self.deactivate_adapter + + def _add_prompt_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + self._registered_adapters[prompt_adapter.id] = prompt_adapter + + @property + def _add_adapter(self): + return self._add_prompt_adapter + + @property + def add_prompt_adapter(self): + return self.add_adapter + + @property + def remove_prompt_adapter(self): + return self.remove_adapter + + def _set_prompt_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + indices_len) = convert_mapping(mapping, self.prompt_adapter_index_to_id, + 
self.prompt_adapter_slots + 1) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + # Maintain the reference + self.indices_len[:] = indices_len + + @property + def set_prompt_adapter_mapping(self): + return self.set_adapter_mapping + + @property + def _set_adapter_mapping(self): + return self._set_prompt_adapter_mapping + + @property + def list_prompt_adapters(self): + return self.list_adapters + + @property + def get_prompt_adapter(self): + return self.get_adapter + + def remove_all_prompt_adapters(self) -> bool: + """Remove all PromptAdapterModel from the manager.""" + self._registered_adapters.clear() + self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots + self._active_adapters.clear() + + @property + def remove_all_adapters(self): + return self.remove_all_prompt_adapters + + +class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[Hashable], None]): + super().__init__(capacity, deactivate_prompt_adapter_fn) + + +class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): + """A model manager that manages multiple prompt_adapters with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + self.prompt_adapter_config = prompt_adapter_config + super().__init__(model, max_num_seqs, \ + max_num_batched_tokens, prompt_adapter_config) + self._registered_adapters: PromptAdapterLRUCache = \ + PromptAdapterLRUCache(self.capacity, + self.deactivate_prompt_adapter) + self._active_adapters: PromptAdapterLRUCache = \ + PromptAdapterLRUCache(self.prompt_adapter_slots, + self._deactivate_prompt_adapter) + + def list_adapters(self) -> Dict[int, PromptAdapterModel]: + """List all registered PromptAdapterModel.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + """Add a PromptAdapterModel to the manager.""" + if prompt_adapter.id not in self._registered_adapters: + self._add_prompt_adapter(prompt_adapter) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(prompt_adapter.id) + was_added = False + return was_added + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + if prompt_adapter_id not in self._active_adapters and len( + self._active_adapters) >= self.prompt_adapter_slots: + self._active_adapters.remove_oldest() + result = super().activate_prompt_adapter(prompt_adapter_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(prompt_adapter_id) + return result + + def remove_oldest_prompt_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + +def create_prompt_adapter_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_manager_cls: Type[PromptAdapterModelManager] \ + = PromptAdapterModelManager, + **kwargs) -> PromptAdapterModelManager: + """Create a PromptAdapterModel for a given model.""" + prompt_adapter_manager = prompt_adapter_manager_cls( + model=model, max_num_seqs=max_num_seqs, \ + max_num_batched_tokens=max_num_batched_tokens, 
prompt_adapter_config=prompt_adapter_config, **kwargs) + return prompt_adapter_manager \ No newline at end of file diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py new file mode 100644 index 0000000000000..c91cbb86d223c --- /dev/null +++ b/vllm/prompt_adapter/request.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from vllm.adapter_commons.request import AdapterRequest + +@dataclass +class PromptAdapterRequest(AdapterRequest): + """ + Request for a Prompt adapter. + """ + + prompt_adapter_name: str + prompt_adapter_id: int + prompt_adapter_local_path: str + prompt_adapter_num_virtual_tokens: int + __hash__ = AdapterRequest.__hash__ + + @property + def adapter_id(self): + return self.prompt_adapter_id + + @property + def name(self): + return self.prompt_adapter_name + + @property + def local_path(self): + return self.prompt_adapter_local_path \ No newline at end of file diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py new file mode 100644 index 0000000000000..e87ecc5aeab80 --- /dev/null +++ b/vllm/prompt_adapter/worker_manager.py @@ -0,0 +1,186 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Dict +import math +import torch + +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, + PromptAdapterModel, + PromptAdapterModelManager, + create_prompt_adapter_manager) +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +logger = logging.getLogger(__name__) + +class WorkerPromptAdapterManager(AbstractWorkerManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. 
+ + Every request, the requested prompt_adapters will be + loaded (unless they are already loaded), + and every other prompt_adapter will be unloaded.""" + + _manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + device: torch.device, + prompt_adapter_config: Type[PromptAdapterConfig], + prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel + ): + self._prompt_adapter_manager: Optional[ + PromptAdapterModelManager] = None + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self._prompt_adapter_model_cls = prompt_adapter_model_cls + self.prompt_adapter_config = prompt_adapter_config + super().__init__(device) + + @property + def is_enabled(self) -> bool: + return True + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._manager_cls, + ) + self._prompt_adapter_manager: PromptAdapterModelManager \ + = prompt_adapter_manager + return prompt_adapter_manager.model + + @property + def set_active_prompt_adapters(self): + return self.set_active_adapters + + def _load_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest + ) -> PromptAdapterModel: + try: + prompt_adapter = self._prompt_adapter_model_cls\ + .from_local_checkpoint( + prompt_adapter_request.prompt_adapter_local_path, + prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + torch_device=str(self.device) + ) + except Exception as e: + raise RuntimeError( + f"Loading prompt_adapter " + f"{prompt_adapter_request.prompt_adapter_local_path}" + f" failed") from e + return prompt_adapter + + def add_dummy_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + pass + + @property + def add_dummy_adapter(self): + return self.add_dummy_prompt_adapter + + @property + def create_manager(self): + return self.create_prompt_adapter_manager + + @property + def _load_adapter(self): + return self._load_prompt_adapter + + @property + def _model_manager(self): + return self._prompt_adapter_manager + + @property + def add_prompt_adapter(self): + return self.add_adapter + + @property + def remove_prompt_adapter(self): + return self.remove_adapter + + @property + def remove_all_prompt_adapters(self): + return self.remove_all_adapters + + @property + def list_prompt_adapters(self): + return self.list_adapters + + @property + def _apply_prompt_adapters(self): + return self._apply_adapters + + +class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Uses an LRU Cache. 
Every request, the requested + prompt_adapters will be loaded (unless they are already loaded) + and least recently used prompt_adapters will + be unloaded if the cache is above capacity.""" + + _prompt_adapter_manager_cls: Type[ + LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) + self._prompt_adapter_manager: \ + LRUCachePromptAdapterModelManager = prompt_adapter_manager + return prompt_adapter_manager.model + + def _apply_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: + prompt_adapters_map = { + prompt_adapter_request.prompt_adapter_id: prompt_adapter_request + for prompt_adapter_request in prompt_adapter_requests + if prompt_adapter_request + } + if len(prompt_adapters_map + ) > self._prompt_adapter_manager.prompt_adapter_slots: + raise RuntimeError( + f"Number of requested prompt_adapters " + f"({len(prompt_adapters_map)}) is greater " + "than the number of GPU prompt_adapter slots " + f"({self._prompt_adapter_manager.prompt_adapter_slots}).") + for prompt_adapter in prompt_adapters_map.values(): + self.add_prompt_adapter(prompt_adapter) + + def add_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if prompt_adapter_request.prompt_adapter_id not in \ + self.list_prompt_adapters(): + # Remove before we load the new prompt_adapter to save memory + if len(self._prompt_adapter_manager + ) + 1 > self._prompt_adapter_manager.capacity: + self._prompt_adapter_manager.remove_oldest_prompt_adapter() + prompt_adapter = self._load_prompt_adapter(prompt_adapter_request) + loaded = self._prompt_adapter_manager.add_prompt_adapter( + prompt_adapter) + else: + # If the prompt_adapter is already loaded, just touch it to + # update its position in the caches + loaded = self._prompt_adapter_manager.get_prompt_adapter( + prompt_adapter_request.prompt_adapter_id) + self._prompt_adapter_manager.activate_prompt_adapter( + prompt_adapter_request.prompt_adapter_id) + return loaded diff --git a/vllm/sequence.py b/vllm/sequence.py index 2f27bf33b166e..eee981cdd7202 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -10,6 +10,7 @@ from vllm.block import LogicalTokenBlock from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -225,12 +226,14 @@ def __init__( block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: self.seq_id = seq_id self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -251,6 +254,7 @@ def __init__( @property def prompt(self) -> Optional[str]: return self.inputs.get("prompt") + return self.inputs.get("prompt") @property def prompt_token_ids(self) -> List[int]: @@ -264,6 +268,11 @@ def multi_modal_data(self) -> Optional["MultiModalData"]: def lora_int_id(self) -> int: 
return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + def get_output_text_to_return(self, buffer_length: int): # We return the full output text if the sequence is finished. truncate = buffer_length and not self.is_finished() @@ -426,6 +435,7 @@ def __init__( embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -440,6 +450,7 @@ def __init__( self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params + self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq @property @@ -464,6 +475,16 @@ def multi_modal_data(self) -> Optional["MultiModalData"]: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + if self.prompt_adapter_request else 0 + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. @@ -637,9 +658,10 @@ def __init__( lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional["MultiModalData"] = None, + multi_modal_data: Optional[MultiModalData] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -648,6 +670,7 @@ def __init__( self.block_tables = block_tables self.pooling_params = pooling_params self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state @@ -672,6 +695,16 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ + if self.prompt_adapter_request else 0 + @property def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d539f56937be1..43c16e794519a 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -7,7 +7,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + VisionLanguageConfig, PromptAdapterConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger 
import init_logger from vllm.model_executor import SamplingMetadata @@ -34,6 +34,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -46,6 +47,7 @@ def __init__( self.device_config = device_config self.cache_config = cache_config self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker @@ -87,7 +89,9 @@ def load_model(self) -> None: lora_config=self.lora_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + cache_config=self.cache_config, + prompt_adapter_config = self.prompt_adapter_config + ) def _prepare_prompt( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3ee394f9912e9..ce7f0d06369c9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,7 +7,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + VisionLanguageConfig, PromptAdapterConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -133,6 +133,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -146,6 +147,7 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -164,6 +166,7 @@ def __init__( lora_config=self.lora_config, vision_language_config=self.vision_language_config, kv_cache_dtype=kv_cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
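For context, the configuration and request plumbing added above is exercised end to end roughly as in the following minimal sketch. It is based on the PromptAdapterRequest fields and the prompt_adapter_request argument threaded through LLM.generate in this series; the base model name, adapter path, id, and virtual-token count are illustrative placeholders, not values taken from the patch.

    from vllm import LLM, SamplingParams
    from vllm.prompt_adapter.request import PromptAdapterRequest

    # Hypothetical PEFT prompt-tuning checkpoint directory; it is expected to
    # hold a "prompt_embeddings" tensor, which from_local_checkpoint() loads.
    pa_request = PromptAdapterRequest(
        prompt_adapter_name="example_pa",
        prompt_adapter_id=1,
        prompt_adapter_local_path="/path/to/peft_prompt_tuning_dir",
        prompt_adapter_num_virtual_tokens=8)

    # Placeholder base model; in this series the soft-prompt hook is wired
    # into the bloom, llama, baichuan, gpt_bigcode and mixtral models.
    llm = LLM(model="your/base-model")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=8),
                           prompt_adapter_request=pa_request)

Once a request carries a prompt adapter, the scheduler and workers below thread its id and virtual-token count through SequenceGroupMetadata into the model runner.
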
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 476e9ba3bb463..443e9333a51f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,13 +11,17 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + VisionLanguageConfig, PromptAdapterConfig) from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -53,6 +57,8 @@ class ModelInput(NamedTuple): num_prefill_tokens: int num_decode_tokens: int num_prefills: int + prompt_adapter_mapping: Optional[PromptAdapterMapping] + prompt_adapter_requests: Set[PromptAdapterRequest] @classmethod def empty(cls, device): @@ -69,6 +75,8 @@ def empty(cls, device): num_prefill_tokens=0, num_decode_tokens=0, num_prefills=0, + prompt_adapter_mapping=None, + prompt_adapter_requests=set(), ) @@ -86,6 +94,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config @@ -96,7 +105,7 @@ def __init__( self.load_config = load_config self.is_driver_worker = is_driver_worker self.vision_language_config = vision_language_config - + self.prompt_adapter_config = prompt_adapter_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -142,6 +151,7 @@ def __init__( self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -153,7 +163,7 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config, + cache_config=self.cache_config ) self.model_memory_usage = m.consumed_memory @@ -181,7 +191,15 @@ def load_model(self) -> None: max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) - + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.device, self.prompt_adapter_config) + self.model = self.prompt_adapter_manager\ + .create_prompt_adapter_manager(self.model) + if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors # via quantization_param_path and this will be deprecated @@ -259,6 +277,9 @@ def _prepare_model_input( lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + prompt_adapter_index_mapping: List[int] = [] + prompt_adapter_prompt_mapping: List[int] = [] + prompt_adapter_requests: Set[PromptAdapterRequest] = set() seq_lens: List[int] = [] prefill_seq_lens: List[int] = [] @@ -418,7 +439,8 @@ def _prepare_model_input( input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id - + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if is_prompt: assert len(seq_ids) == 1 num_prefills += 1 @@ -434,7 +456,7 @@ def _prepare_model_input( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - + lora_index_mapping += [lora_id] * query_len lora_prompt_mapping.extend( [lora_id] * @@ -442,17 +464,20 @@ def _prepare_model_input( and seq_group_metadata.sampling_params.prompt_logprobs is not None else 1)) - mm_data = seq_group_metadata.multi_modal_data - if mm_data is not None: - # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) - for k, v in mm_kwargs.items(): - multi_modal_kwargs_list[k].append(v) + if prompt_adapter_id > 0: + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + prompt_adapter_index_mapping += [prompt_adapter_id] * query_len + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not @@ -515,6 +540,8 @@ def _prepare_model_input( seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) + prompt_adapter_index_mapping.append(0) + batch_size = graph_batch_size num_decode_tokens = batch_size @@ -637,6 +664,14 @@ def _prepare_model_input( else: lora_mapping = None + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + 
multi_modal_kwargs = { k: torch.cat(v, dim=0).to(self.device) for k, v in multi_modal_kwargs_list.items() @@ -655,13 +690,15 @@ def _prepare_model_input( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests, ) def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor], Set[PromptAdapterRequest], PromptAdapterMapping]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. @@ -678,7 +715,10 @@ def prepare_input_tensors( num_prefill_tokens, num_decode_tokens, num_prefills, + prompt_adapter_mapping, + prompt_adapter_requests ) = self._prepare_model_input(seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) @@ -695,6 +735,8 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, + "prompt_adapter_requests": prompt_adapter_requests, + "prompt_adapter_mapping": prompt_adapter_mapping, } if attn_metadata: metadata_dict.update(attn_metadata.asdict_zerocopy()) @@ -708,6 +750,9 @@ def prepare_input_tensors( lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + prompt_adapter_mapping = metadata_dict.pop("prompt_adapter_mapping") + prompt_adapter_requests = metadata_dict.pop("prompt_adapter_requests") + if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -722,7 +767,7 @@ def prepare_input_tensors( return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_kwargs) + multi_modal_kwargs, prompt_adapter_requests, prompt_adapter_mapping) @torch.inference_mode() def execute_model( @@ -731,12 +776,16 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_kwargs + lora_requests, lora_mapping, multi_modal_kwargs, + prompt_adapter_requests, prompt_adapter_mapping ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) + if self.prompt_adapter_config: + self.set_active_prompt_adapters(prompt_adapter_requests, prompt_adapter_mapping) + # Currently cuda graph is only supported by the decode phase. 
        prefill_meta = attn_metadata.prefill_metadata
         decode_meta = attn_metadata.decode_metadata
 
@@ -871,6 +920,37 @@ def list_loras(self) -> Set[int]:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.list_loras()
 
+    def remove_all_prompt_adapters(self):
+        if not self.prompt_adapter_manager:
+            raise RuntimeError("PromptAdapter is not enabled.")
+        self.prompt_adapter_manager.remove_all_prompt_adapters()
+
+    def set_active_prompt_adapters(
+            self, prompt_adapter_requests: Set[PromptAdapterRequest],
+            prompt_adapter_mapping: PromptAdapterMapping
+    ) -> None:
+        if not self.prompt_adapter_manager:
+            raise RuntimeError("PromptAdapter is not enabled.")
+        self.prompt_adapter_manager.set_active_prompt_adapters(
+            prompt_adapter_requests, prompt_adapter_mapping)
+
+    def add_prompt_adapter(
+            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
+        if not self.prompt_adapter_manager:
+            raise RuntimeError("PromptAdapter is not enabled.")
+        return self.prompt_adapter_manager.add_prompt_adapter(
+            prompt_adapter_request)
+
+    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+        if not self.prompt_adapter_manager:
+            raise RuntimeError("PromptAdapter is not enabled.")
+        return self.prompt_adapter_manager.remove_prompt_adapter(prompt_adapter_id)
+
+    def list_prompt_adapters(self) -> Set[int]:
+        if not self.prompt_adapter_manager:
+            raise RuntimeError("PromptAdapter is not enabled.")
+        return self.prompt_adapter_manager.list_prompt_adapters()
+
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
         """Cuda graph capture a model.
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 7a378a862d0c0..5325ff54d64f8 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -8,12 +8,13 @@
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
-                         SpeculativeConfig, VisionLanguageConfig)
+                         SpeculativeConfig, VisionLanguageConfig, PromptAdapterConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
@@ -45,6 +46,7 @@ def __init__(
         lora_config: Optional[LoRAConfig] = None,
         vision_language_config: Optional[VisionLanguageConfig] = None,
         speculative_config: Optional[SpeculativeConfig] = None,
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -57,6 +59,7 @@ def __init__(
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
         self.load_config = load_config
+        self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -83,6 +86,7 @@ def __init__(
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
             vision_language_config=vision_language_config,
+            prompt_adapter_config = prompt_adapter_config
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
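The hooks above feed the actual soft-prompt substitution performed once per forward pass. Condensed from apply_prompt_adapter in vllm/prompt_adapter/layers.py together with the engine-side padding of prompt_token_ids by [0] * num_virtual_tokens, the mechanism is roughly the following sketch (the function name and signature are illustrative, not part of the patch):

    import torch

    def apply_soft_prompt(hidden_states: torch.Tensor,
                          positions: torch.Tensor,
                          prompt_embedding: torch.Tensor) -> torch.Tensor:
        # The engine reserved the first num_virtual_tokens positions by
        # prepending placeholder token id 0; overwrite the embeddings at
        # those positions with the learned soft-prompt rows.
        is_virtual = positions < prompt_embedding.shape[0]
        hidden_states[is_virtual] = prompt_embedding[positions[is_virtual]]
        return hidden_states

Note that activate_prompt_adapter attaches a single adapter to the model as prefix_encoder, so the substitution always uses the most recently activated adapter.
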
@@ -326,6 +330,17 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.model_runner.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.prompt_adapter_manager.add_prompt_adapter( + prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.prompt_adapter_manager.remove_lora(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.prompt_adapter_manager.list_prompt_adapters() + @property def max_model_len(self) -> int: return self.model_config.max_model_len From 96b4a1a3b36eec4b74901438cb4799ee498dd3f1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 06:58:58 -0400 Subject: [PATCH 02/80] Run yapf and ruff --- format.sh | 6 +- tests/worker/test_model_runner.py | 4 +- vllm/adapter_commons/models.py | 40 +++++++++--- vllm/adapter_commons/request.py | 7 ++- vllm/adapter_commons/worker_manager.py | 17 ++++- vllm/config.py | 1 + vllm/core/scheduler.py | 3 +- vllm/engine/arg_utils.py | 14 ++--- vllm/engine/async_llm_engine.py | 19 +++--- vllm/engine/llm_engine.py | 57 ++++++++--------- vllm/entrypoints/llm.py | 31 ++++----- vllm/executor/cpu_executor.py | 2 +- vllm/executor/executor_base.py | 26 ++++---- vllm/executor/gpu_executor.py | 2 +- vllm/lora/models.py | 12 ++-- vllm/lora/request.py | 8 ++- vllm/lora/worker_manager.py | 14 ++--- vllm/model_executor/models/baichuan.py | 2 +- vllm/model_executor/models/bloom.py | 2 +- vllm/model_executor/models/gpt_bigcode.py | 2 +- vllm/model_executor/models/llama.py | 2 +- vllm/model_executor/models/mixtral.py | 2 +- vllm/prompt_adapter/layers.py | 8 ++- vllm/prompt_adapter/models.py | 77 +++++++++++------------ vllm/prompt_adapter/request.py | 10 +-- vllm/prompt_adapter/worker_manager.py | 30 +++++---- vllm/sequence.py | 20 +++--- vllm/worker/cpu_model_runner.py | 5 +- vllm/worker/cpu_worker.py | 4 +- vllm/worker/embedding_model_runner.py | 52 ++++++++------- vllm/worker/model_runner.py | 77 ++++++++++------------- vllm/worker/worker.py | 18 +++--- 32 files changed, 306 insertions(+), 268 deletions(-) diff --git a/format.sh b/format.sh index 6057b69af8ce8..291de31ef957f 100755 --- a/format.sh +++ b/format.sh @@ -26,7 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) -CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') +# CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { @@ -41,7 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" +# tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" YAPF_FLAGS=( '--recursive' @@ -112,6 +112,8 @@ mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml 
+mypy vllm/prompt_adapter --config-file pyproject.toml + # If git diff returns a file that is in the skip list, the file may be checked anyway: diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 87391e09d34cd..f8ddb7618b9f4 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -353,8 +353,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): seq_group_metadata_list.append(seq_group_metadata) decode_metadata_list.append(seq_group_metadata) - (input_tokens, input_positions, attn_metadata, _, _, _, - _, _, _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, _, _, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index 630aeb9ba0101..403f5115ebc34 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -1,13 +1,16 @@ -import os -import json -from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, TypeVar -import torch +from abc import abstractproperty +from typing import Any, Callable, Dict, Hashable, Optional, TypeVar + from torch import nn -from vllm.utils import LRUCache + from vllm.logger import init_logger +from vllm.utils import LRUCache + logger = init_logger(__name__) + class AdapterModel: + def __init__(self, model_id=None): self.id = model_id @@ -17,9 +20,14 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): # Load weights or embeddings from local checkpoint raise NotImplementedError("Subclasses must implement this method.") + T = TypeVar('T') + + class AdapterLRUCache(LRUCache[T]): - def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], None]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], + None]): super().__init__(capacity) self.deactivate_fn = deactivate_fn @@ -28,7 +36,9 @@ def _on_remove(self, key: Hashable, value: T): self.deactivate_fn(key) return super()._on_remove(key, value) + class AdapterModelManager: + def __init__( self, model: nn.Module, @@ -42,10 +52,19 @@ def __init__( # Dict instead of a Set for compatibility with LRUCache. self._active_adapters: Dict[int, None] = {} self.adapter_type = 'Adapter' + self._last_mapping = None def __len__(self) -> int: return len(self._registered_adapters) + @abstractproperty + def adapter_slots(self): + ... + + @abstractproperty + def capacity(self): + ... 
+ def _deactivate_adapter(self, adapter_id: int): raise NotImplementedError("Subclasses must implement this method.") @@ -65,7 +84,7 @@ def _add_adapter(self, adapter: Any): def add_adapter(self, adapter: Any) -> bool: if adapter.id not in self._registered_adapters: if len(self._registered_adapters) >= self.capacity: - raise RuntimeError("No free "+self.adapter_type+" slots.") + raise RuntimeError("No free " + self.adapter_type + " slots.") self._add_adapter(adapter) return True return False @@ -74,7 +93,10 @@ def set_adapter_mapping(self, mapping: Any) -> None: if self._last_mapping != mapping: self._set_adapter_mapping(mapping) self._last_mapping = mapping - + + def _set_adapter_mapping(self, mapping: Any) -> None: + raise NotImplementedError("Subclasses must implement this method.") + def remove_adapter(self, adapter_id: int) -> bool: self.deactivate_adapter(adapter_id) return bool(self._registered_adapters.pop(adapter_id, None)) @@ -83,4 +105,4 @@ def list_adapters(self) -> Dict[int, Any]: return dict(self._registered_adapters) def get_adapter(self, adapter_id: int) -> Optional[Any]: - return self._registered_adapters.get(adapter_id, None) \ No newline at end of file + return self._registered_adapters.get(adapter_id, None) diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py index 26d59babdb987..7e50ae184ee22 100644 --- a/vllm/adapter_commons/request.py +++ b/vllm/adapter_commons/request.py @@ -1,19 +1,20 @@ from dataclasses import dataclass + @dataclass class AdapterRequest: """ Base class for adapter requests. """ + adapter_id: int def __post_init__(self): if self.adapter_id < 1: - raise ValueError( - f"id must be > 0, got {self.adapter_id}") + raise ValueError(f"id must be > 0, got {self.adapter_id}") def __eq__(self, value: object) -> bool: return isinstance( value, self.__class__) and self.adapter_id == value.adapter_id def __hash__(self) -> int: - return hash(self.adapter_id) \ No newline at end of file + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index ee113f1e2080b..e6a4173efba4e 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -1,10 +1,16 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Any, List, Optional, Set, Type, Dict +from typing import Any, Optional, Set + import torch +from vllm.adapter_commons.models import AdapterModelManager + + class AbstractWorkerManager(ABC): + def __init__(self, device: torch.device): self.device = device + self._model_manager: AdapterModelManager = None @abstractproperty def is_enabled(self) -> bool: @@ -14,7 +20,8 @@ def is_enabled(self) -> bool: def create_manager(self, model: torch.nn.Module) -> Any: ... - def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: self._apply_adapters(requests) self._model_manager.set_adapter_mapping(mapping) @@ -22,6 +29,10 @@ def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> Non def add_dummy_adapter(self, request: Any) -> bool: ... + @abstractmethod + def _load_adapter(self, request: Any) -> Any: + ... 
+ def add_adapter(self, adapter_request: Any) -> bool: if adapter_request.adapter_id in self.list_adapters(): return False @@ -59,4 +70,4 @@ def remove_all_adapters(self): self._model_manager.remove_all_adapters() def list_adapters(self) -> Set[int]: - return set(self._model_manager.list_adapters()) \ No newline at end of file + return set(self._model_manager.list_adapters()) diff --git a/vllm/config.py b/vllm/config.py index 3942b11c35c7f..21d42c5673139 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1106,6 +1106,7 @@ def __post_init__(self): if self.max_cpu_prompt_adapters is None: self.max_cpu_prompt_adapters = self.max_prompt_adapters + @dataclass class VisionLanguageConfig: """Configs the input data format and how models should run for diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4c5b244cb1e42..be9992b3deaa6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -157,7 +157,8 @@ def _sort_by_lora_ids(self): def _sort_by_prompt_adapter_ids(self): self.scheduled_seq_groups = sorted( self.scheduled_seq_groups, - key=lambda g: (g.seq_group.prompt_adapter_id, g.seq_group.request_id)) + key=lambda g: + (g.seq_group.prompt_adapter_id, g.seq_group.request_id)) @property def lora_requests(self) -> Set[LoRARequest]: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4f8006e914a3a..d7b104891c395 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,8 +7,9 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig, VisionLanguageConfig, PromptAdapterConfig) + ParallelConfig, PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig, + VisionLanguageConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import str_to_int_tuple @@ -503,9 +504,6 @@ def add_cli_args( 'Enabling this will use the fully sharded layers. 
' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) - # parser.add_argument('--enable-prompt-adapter', - # action='store_true', - # help='If True, enable handling of PromptAdapters.') parser.add_argument('--max-prompt-adapters', type=int, default=EngineArgs.max_prompt_adapters, @@ -730,10 +728,10 @@ def create_engine_config(self, ) -> EngineConfig: download_dir=self.download_dir, model_loader_extra_config=self.model_loader_extra_config, ) - + prompt_adapter_config = PromptAdapterConfig( - max_prompt_adapters=self.max_prompt_adapters) #if self.enable_prompt_adapter else None - + max_prompt_adapters=self.max_prompt_adapters) + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 943402c865bd2..171d4670087c3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,6 +17,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -279,12 +280,13 @@ async def process_model_inputs_async( multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( - self, - request_id: str, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, + self, + request_id: str, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -301,7 +303,7 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) async def check_health_async(self) -> None: self.model_executor.check_health() @@ -545,6 +547,7 @@ async def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncStream: if self.log_requests: if isinstance(inputs, str): @@ -598,7 +601,7 @@ async def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) return stream diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fae815f4e3a70..e19514f73eda1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -9,8 +9,8 @@ import vllm from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -24,10 +24,10 @@ from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request 
import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, @@ -456,8 +456,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -465,8 +464,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -495,8 +493,9 @@ def process_model_inputs( prompt_token_ids = inputs["prompt_token_ids"] if prompt_adapter_request: - prompt_token_ids = [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - + prompt_token_ids + prompt_token_ids = \ + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + + prompt_token_ids return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), @@ -558,10 +557,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs(request_id=request_id, - inputs=inputs, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + processed_inputs = self.process_model_inputs( + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -569,8 +569,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) def _create_sequence_group_with_sampling( self, @@ -601,12 +600,13 @@ def _create_sequence_group_with_sampling( self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) return seq_group @@ -623,12 +623,13 @@ def _create_sequence_group_with_pooling( # Defensive copy of PoolingParams, which are used by the pooler pooling_params = pooling_params.clone() # Create the sequence group. 
- seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - prompt_adapter_request=prompt_adapter_request) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -1002,6 +1003,6 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> List[int]: return self.model_executor.list_prompt_adapters() - + def check_health(self) -> None: self.model_executor.check_health() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ac20d7b13c3b5..d1762de87bd8d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -11,9 +11,9 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext @@ -301,8 +301,7 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -395,6 +394,7 @@ def encode( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. 
@@ -443,6 +443,7 @@ def encode( inputs=inputs, params=pooling_params, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -525,22 +526,22 @@ def _validate_and_add_requests( params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, - inputs, - params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + self.llm_engine.add_request( + request_id, + inputs, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) def _run_engine( self, *, use_tqdm: bool diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 4eb8fb89c399f..007d90fefb0fb 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -98,7 +98,7 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> Set[int]: return self.driver_worker.list_prompt_adapters() - + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 145c451bd2f8f..edbb89b125717 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -2,8 +2,9 @@ from typing import List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -17,19 +18,14 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. 
""" - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig] - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index c4bfaaa12bb5c..31e567eb3d0c7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -117,7 +117,7 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> Set[int]: return self.driver_worker.list_prompt_adapters() - + def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3b09d75f3daf5..2c80aacc60c09 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -10,6 +10,8 @@ import torch from torch import nn +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, @@ -18,8 +20,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) -from vllm.utils import LRUCache, is_pin_memory_available -from vllm.adapter_commons.models import AdapterModel, AdapterLRUCache, AdapterModelManager +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -489,7 +490,7 @@ def _deactivate_adapter(self): @property def deactivate_lora(self): return self.deactivate_adapter - + def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: return @@ -563,7 +564,7 @@ def list_loras(self): @property def get_lora(self): return self.get_adapter - + def remove_all_loras(self): """Remove all LoRAModels from the manager.""" self._registered_adapters.clear() @@ -719,7 +720,8 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: class LoRALRUCache(AdapterLRUCache[LoRAModel]): - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], None]): + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + None]): super().__init__(capacity, deactivate_lora_fn) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index eb301311a7c10..2d10d037760e2 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from typing import Optional + from vllm.adapter_commons.request import AdapterRequest + @dataclass class LoRARequest(AdapterRequest): """ @@ -25,11 +27,11 @@ class LoRARequest(AdapterRequest): @property def adapter_id(self): return self.lora_int_id - + @property def name(self): return self.lora_name - + @property def 
local_path(self): - return self.lora_local_path \ No newline at end of file + return self.lora_local_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index d247c202d472b..26b420f5c9b5d 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,19 +1,18 @@ -from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch +from vllm.adapter_commons.worker_manager import AbstractWorkerManager from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest -from vllm.adapter_commons.worker_manager import AbstractWorkerManager logger = init_logger(__name__) + class WorkerLoRAManager(AbstractWorkerManager): """WorkerLoRAManager that manages LoRA models on the worker side. @@ -39,14 +38,14 @@ def __init__( self.embedding_padding_modules = embedding_padding_modules self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens + self.max_num_batched_tokens = max_num_batched_tokens self.vocab_size = vocab_size self.lora_config = lora_config self.max_position_embeddings = max_position_embeddings super().__init__(device) # Lazily initialized by create_lora_manager. self._lora_manager: LoRAModelManager - + @contextmanager def dummy_lora_cache(self): """Use this context manager to reuse the dummy lora model @@ -54,7 +53,7 @@ def dummy_lora_cache(self): self._cached_dummy_lora = None yield self._cached_dummy_lora = False - + @property def is_enabled(self) -> bool: return True @@ -172,8 +171,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _manager_cls: Type[ - LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( self, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 08599941507cc..7608c4caeff40 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -43,8 +43,8 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.prompt_adapter.layers import apply_prompt_adapter +from vllm.sequence import SamplerOutput def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 29d935d723af4..78c0d7756a610 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -39,8 +39,8 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.prompt_adapter.layers import apply_prompt_adapter +from vllm.sequence import SamplerOutput def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 
b50845f236151..dd1a7750a49ee 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,8 +39,8 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.prompt_adapter.layers import apply_prompt_adapter +from vllm.sequence import SamplerOutput class GPTBigCodeAttention(nn.Module): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9618b74f282f8..f9088449e007c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -46,9 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once -from vllm.prompt_adapter.layers import apply_prompt_adapter class LlamaMLP(nn.Module): diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b7c6f3db57405..8948e1cccb4ac 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -51,9 +51,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once -from vllm.prompt_adapter.layers import apply_prompt_adapter class MixtralMoE(nn.Module): diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 799a2e38feecc..b06b591b061b5 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from typing import Tuple + import torch + @dataclass class PromptAdapterMapping: # Per every token in input_ids: @@ -13,9 +15,11 @@ def __post_init__(self): self.index_mapping = tuple(self.index_mapping) self.prompt_mapping = tuple(self.prompt_mapping) -def apply_prompt_adapter(instance, hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + +def apply_prompt_adapter(instance, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: if hasattr(instance, 'prefix_encoder'): soft_prompt = instance.prefix_encoder.prompt_embedding indices = (positions < soft_prompt.shape[0]) hidden_states[indices] = soft_prompt[positions[indices]] - return hidden_states \ No newline at end of file + return hidden_states diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 84eb3e8ee03e1..ac6554a53836b 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -1,13 +1,15 @@ import logging -from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, Tuple -from peft.utils import load_peft_weights -import torch import math +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +from peft.utils import load_peft_weights from torch import nn -from vllm.adapter_commons.models import AdapterModel, AdapterLRUCache, AdapterModelManager -from vllm.prompt_adapter.layers import PromptAdapterMapping + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + 
AdapterModelManager) from vllm.config import PromptAdapterConfig -from vllm.utils import LRUCache +from vllm.prompt_adapter.layers import PromptAdapterMapping logger = logging.getLogger(__name__) @@ -19,24 +21,23 @@ def get_prompt_adapter_id(): _GLOBAL_PROMPT_ADAPTER_ID += 1 return _GLOBAL_PROMPT_ADAPTER_ID + def convert_mapping( mapping: PromptAdapterMapping, - prompt_adapter_index_to_id: List[Optional[int]], - max_prompt_adapters: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: + prompt_adapter_index_to_id: List[Optional[int]], max_prompt_adapters: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: """Converts PromptAdapterMapping to index tensors. Args: - mapping: PromptAdapterMapping mapping rows in a batch to PromptAdapter ids. - prompt_adapter_index_to_id: List mapping PromptAdapter ids to PromptAdapter indices. + mapping: PromptAdapterMapping mapping rows in a batch to ids. + prompt_adapter_index_to_id: List mapping PromptAdapter ids to indices. max_prompt_adapters: Maximum number of PromptAdapters. Returns: A tuple of tensors: base_indices: Tensor of shape [batch_size] mapping batch rows to PromptAdapter indices. sampler_indices: Tensor of shape [batch_size] mapping requests to - PromptAdapter indices for sampler. For generation, this will be the + PromptAdapter indices for sampler. For generation, this will be same as base_indicies. For prefill, this will map requests to PromptAdapter indices. sampler_indices_padded: Tensor of shape [batch_size] mapping @@ -56,10 +57,10 @@ def convert_mapping( prompt_adapter_idx = None for i in range(len(index_mapping_indices)): # TODO index can be slow. optimize - prompt_adapter_idx = (prompt_adapter_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) + prompt_adapter_idx = (prompt_adapter_index_to_id.index( + index_mapping_indices[i]) if index_mapping_indices[i] > 0 else -1) prompt_adapter_indices[i] = prompt_adapter_idx - + indices_list: List[Union[List[int], torch.Tensor]] = [ index_mapping_indices, prompt_adapter_indices ] @@ -70,7 +71,8 @@ def convert_mapping( base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_prompt_adapters - 1 + sampler_indices_padded[sampler_indices_padded == + -1] = max_prompt_adapters - 1 sampler_indices_padded = ( torch.arange( 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + @@ -80,11 +82,11 @@ def convert_mapping( base_indices.shape[-1], sampler_indices.shape[-1], sampler_indices_padded.shape[-1] ] - return (base_indices, sampler_indices, sampler_indices_padded, - indices_len) + return (base_indices, sampler_indices, sampler_indices_padded, indices_len) class PromptAdapterModel(AdapterModel): + def __init__(self, prompt_adapter_id=None, num_virtual_tokens=None, @@ -128,10 +130,9 @@ def __init__( self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.prompt_adapter_config = prompt_adapter_config - self._last_mapping = None self.model.prompt_adapter_manager = self self.adapter_type = 'PromptAdapter' - + self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") @@ -201,19 +202,11 @@ def _deactivate_prompt_adapter(self, prompt_adapter_id: int): def _deactivate_adapter(self): return self._deactivate_prompt_adapter - def deactivate_prompt_adapter(self, prompt_adapter_id: int) -> 
bool: - """Remove a prompt_adapter from a GPU buffer.""" - if prompt_adapter_id in self._active_adapters: - self._deactivate_prompt_adapter(prompt_adapter_id) - self._active_adapters.pop(prompt_adapter_id) - return True - return False - @property def deactivate_prompt_adapter(self): return self.deactivate_adapter - def _add_prompt_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + def _add_prompt_adapter(self, prompt_adapter: PromptAdapterModel): self._registered_adapters[prompt_adapter.id] = prompt_adapter @property @@ -228,9 +221,11 @@ def add_prompt_adapter(self): def remove_prompt_adapter(self): return self.remove_adapter - def _set_prompt_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + def _set_prompt_adapter_mapping(self, + mapping: PromptAdapterMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, - indices_len) = convert_mapping(mapping, self.prompt_adapter_index_to_id, + indices_len) = convert_mapping(mapping, + self.prompt_adapter_index_to_id, self.prompt_adapter_slots + 1) self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) @@ -246,7 +241,7 @@ def set_prompt_adapter_mapping(self): @property def _set_adapter_mapping(self): return self._set_prompt_adapter_mapping - + @property def list_prompt_adapters(self): return self.list_adapters @@ -254,8 +249,8 @@ def list_prompt_adapters(self): @property def get_prompt_adapter(self): return self.get_adapter - - def remove_all_prompt_adapters(self) -> bool: + + def remove_all_prompt_adapters(self): """Remove all PromptAdapterModel from the manager.""" self._registered_adapters.clear() self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots @@ -267,10 +262,11 @@ def remove_all_adapters(self): class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): - def __init__(self, capacity: int, - deactivate_prompt_adapter_fn: Callable[[Hashable], None]): + + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[int], None]): super().__init__(capacity, deactivate_prompt_adapter_fn) - + class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): """A model manager that manages multiple prompt_adapters with LRU cache.""" @@ -337,5 +333,6 @@ def create_prompt_adapter_manager( """Create a PromptAdapterModel for a given model.""" prompt_adapter_manager = prompt_adapter_manager_cls( model=model, max_num_seqs=max_num_seqs, \ - max_num_batched_tokens=max_num_batched_tokens, prompt_adapter_config=prompt_adapter_config, **kwargs) - return prompt_adapter_manager \ No newline at end of file + max_num_batched_tokens=max_num_batched_tokens, \ + prompt_adapter_config=prompt_adapter_config, **kwargs) + return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py index c91cbb86d223c..31ba38c420583 100644 --- a/vllm/prompt_adapter/request.py +++ b/vllm/prompt_adapter/request.py @@ -1,12 +1,14 @@ from dataclasses import dataclass + from vllm.adapter_commons.request import AdapterRequest + @dataclass class PromptAdapterRequest(AdapterRequest): """ Request for a Prompt adapter. 
""" - + prompt_adapter_name: str prompt_adapter_id: int prompt_adapter_local_path: str @@ -16,11 +18,11 @@ class PromptAdapterRequest(AdapterRequest): @property def adapter_id(self): return self.prompt_adapter_id - + @property def name(self): return self.prompt_adapter_name - + @property def local_path(self): - return self.prompt_adapter_local_path \ No newline at end of file + return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index e87ecc5aeab80..a79314d481d60 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -1,19 +1,19 @@ import logging -from abc import ABC, abstractmethod, abstractproperty -from typing import Any, List, Optional, Set, Type, Dict -import math +from typing import Any, Optional, Set, Type + import torch +from vllm.adapter_commons.worker_manager import AbstractWorkerManager from vllm.config import PromptAdapterConfig from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, PromptAdapterModel, PromptAdapterModelManager, create_prompt_adapter_manager) from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.prompt_adapter.layers import PromptAdapterMapping -from vllm.adapter_commons.worker_manager import AbstractWorkerManager + logger = logging.getLogger(__name__) + class WorkerPromptAdapterManager(AbstractWorkerManager): """WorkerPromptAdapterManager that manages prompt_adapter models on the worker side. @@ -22,8 +22,7 @@ class WorkerPromptAdapterManager(AbstractWorkerManager): loaded (unless they are already loaded), and every other prompt_adapter will be unloaded.""" - _manager_cls: Type[ - PromptAdapterModelManager] = PromptAdapterModelManager + _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager def __init__( self, @@ -36,7 +35,7 @@ def __init__( self._prompt_adapter_manager: Optional[ PromptAdapterModelManager] = None self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens + self.max_num_batched_tokens = max_num_batched_tokens self._prompt_adapter_model_cls = prompt_adapter_model_cls self.prompt_adapter_config = prompt_adapter_config super().__init__(device) @@ -56,8 +55,7 @@ def create_prompt_adapter_manager( prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_manager_cls=self._manager_cls, ) - self._prompt_adapter_manager: PromptAdapterModelManager \ - = prompt_adapter_manager + self._prompt_adapter_manager = prompt_adapter_manager return prompt_adapter_manager.model @property @@ -83,7 +81,7 @@ def _load_prompt_adapter( def add_dummy_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - pass + return True @property def add_dummy_adapter(self): @@ -92,7 +90,7 @@ def add_dummy_adapter(self): @property def create_manager(self): return self.create_prompt_adapter_manager - + @property def _load_adapter(self): return self._load_prompt_adapter @@ -120,7 +118,7 @@ def list_prompt_adapters(self): @property def _apply_prompt_adapters(self): return self._apply_adapters - + class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): """WorkerPromptAdapterManager that manages @@ -147,7 +145,7 @@ def create_prompt_adapter_manager( self._prompt_adapter_manager: \ LRUCachePromptAdapterModelManager = prompt_adapter_manager return prompt_adapter_manager.model - + def _apply_adapters( self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: prompt_adapters_map = { @@ -165,8 +163,8 @@ def _apply_adapters( for 
prompt_adapter in prompt_adapters_map.values(): self.add_prompt_adapter(prompt_adapter) - def add_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: + def add_adapter(self, + prompt_adapter_request: PromptAdapterRequest) -> bool: if prompt_adapter_request.prompt_adapter_id not in \ self.list_prompt_adapters(): # Remove before we load the new prompt_adapter to save memory diff --git a/vllm/sequence.py b/vllm/sequence.py index eee981cdd7202..6a761a4b47403 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -10,8 +10,8 @@ from vllm.block import LogicalTokenBlock from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -220,13 +220,13 @@ class Sequence: """ def __init__( - self, - seq_id: int, - inputs: LLMInputs, - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + seq_id: int, + inputs: LLMInputs, + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -484,7 +484,7 @@ def prompt_adapter_id(self) -> int: def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ if self.prompt_adapter_request else 0 - + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. 
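A minimal usage sketch of the request plumbing above, illustrative only and not part of the patch: the adapter name, local path, and token count are placeholders, the keyword names follow the dataclass fields in vllm/prompt_adapter/request.py, and prompt_adapter_num_virtual_tokens is assumed to be a constructor field because vllm/sequence.py reads it off the request.

    from vllm.prompt_adapter.request import PromptAdapterRequest

    # Hypothetical request; only the field names come from the diff above.
    pa_request = PromptAdapterRequest(
        prompt_adapter_name="example_pa",
        prompt_adapter_id=1,  # ids below 1 fail the AdapterRequest check
        prompt_adapter_local_path="/path/to/peft/checkpoint",
        prompt_adapter_num_virtual_tokens=8,
    )

    # A Sequence or SequenceGroup built with prompt_adapter_request=pa_request
    # then reports prompt_adapter_id == 1 and
    # prompt_adapter_num_virtual_tokens == 8; both properties fall back to 0
    # when no request is attached.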
@@ -704,7 +704,7 @@ def prompt_adapter_id(self) -> int: def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ if self.prompt_adapter_request else 0 - + @property def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 43c16e794519a..84951e8ab8363 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -6,8 +6,8 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -90,7 +90,6 @@ def load_model(self) -> None: parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, cache_config=self.cache_config, - prompt_adapter_config = self.prompt_adapter_config ) def _prepare_prompt( diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index ce7f0d06369c9..d1801cc8950bb 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,8 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 465130d10e2f9..37f94dc28a5b0 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -4,14 +4,16 @@ from vllm.attention import AttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner @@ -32,6 +34,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, ): super().__init__(model_config, parallel_config, @@ -42,7 +45,8 @@ def __init__( lora_config=lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + prompt_adapter_config=prompt_adapter_config) @torch.inference_mode() def execute_model( @@ -51,12 +55,17 @@ def execute_model( kv_caches: 
List[torch.Tensor], ) -> Optional[PoolerOutput]: (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input + lora_requests, lora_mapping, multi_modal_input, + prompt_adapter_requests, prompt_adapter_mapping ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) + if self.prompt_adapter_config: + self.set_active_prompt_adapters(prompt_adapter_requests, + prompt_adapter_mapping) + # Currently cuda graph is only supported by the decode phase. prefill_meta = attn_metadata.prefill_metadata decode_meta = attn_metadata.decode_metadata @@ -90,24 +99,16 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: + Set[LoRARequest], LoRAMapping, torch.Tensor, + Set[PromptAdapterRequest], PromptAdapterMapping]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. - ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - _, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, seq_lens, _, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, + num_prefill_tokens, num_decode_tokens, num_prefills, + prompt_adapter_mapping, prompt_adapter_requests + ) = self._prepare_model_input(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, seq_lens) @@ -117,11 +118,13 @@ def prepare_input_tensors( "input_positions": input_positions, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_kwargs": multi_modal_kwargs, + "multi_modal_input": multi_modal_input, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, + "prompt_adapter_requests": prompt_adapter_requests, + "prompt_adapter_mapping": prompt_adapter_mapping, } if attn_metadata: metadata_dict.update(attn_metadata.asdict_zerocopy()) @@ -132,7 +135,11 @@ def prepare_input_tensors( input_positions = metadata_dict.pop("input_positions") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + multi_modal_input = metadata_dict.pop("multi_modal_input") + prompt_adapter_mapping = metadata_dict.pop( + "prompt_adapter_mapping") + prompt_adapter_requests = metadata_dict.pop( + "prompt_adapter_requests") if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -143,7 +150,8 @@ def prepare_input_tensors( prompt_lens=None) return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_kwargs) + lora_requests, lora_mapping, multi_modal_input, + prompt_adapter_requests, prompt_adapter_mapping) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 443e9333a51f6..a5322ea86134d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,14 +10,17 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, 
ParallelConfig, SchedulerConfig, - VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( @@ -163,8 +166,7 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config - ) + cache_config=self.cache_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -191,15 +193,15 @@ def load_model(self) -> None: max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) - + if self.prompt_adapter_config: self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.device, self.prompt_adapter_config) + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) self.model = self.prompt_adapter_manager\ .create_prompt_adapter_manager(self.model) - + if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors # via quantization_param_path and this will be deprecated @@ -440,7 +442,7 @@ def _prepare_model_input( input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id prompt_adapter_id = seq_group_metadata.prompt_adapter_id - + if is_prompt: assert len(seq_ids) == 1 num_prefills += 1 @@ -456,7 +458,7 @@ def _prepare_model_input( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - + lora_index_mapping += [lora_id] * query_len lora_prompt_mapping.extend( [lora_id] * @@ -467,17 +469,13 @@ def _prepare_model_input( if prompt_adapter_id > 0: prompt_adapter_requests.add( seq_group_metadata.prompt_adapter_request) - + prompt_adapter_index_mapping += [prompt_adapter_id] * query_len prompt_adapter_prompt_mapping.extend( [prompt_adapter_id] * (query_len if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not @@ -541,7 +539,7 @@ def _prepare_model_input( block_tables.append([]) lora_index_mapping.append(0) prompt_adapter_index_mapping.append(0) - + batch_size = graph_batch_size num_decode_tokens = batch_size @@ -698,27 +696,17 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor], Set[PromptAdapterRequest], PromptAdapterMapping]: + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor], + Set[PromptAdapterRequest], PromptAdapterMapping]: if 
self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. - ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - query_lens, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - prompt_adapter_mapping, - prompt_adapter_requests - ) = self._prepare_model_input(seq_group_metadata_list) - + (input_tokens, input_positions, attn_metadata, seq_lens, + query_lens, lora_mapping, lora_requests, multi_modal_kwargs, + slot_mapping, num_prefill_tokens, num_decode_tokens, num_prefills, + prompt_adapter_mapping, prompt_adapter_requests + ) = self._prepare_model_input(seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) @@ -750,9 +738,11 @@ def prepare_input_tensors( lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - prompt_adapter_mapping = metadata_dict.pop("prompt_adapter_mapping") - prompt_adapter_requests = metadata_dict.pop("prompt_adapter_requests") - + prompt_adapter_mapping = metadata_dict.pop( + "prompt_adapter_mapping") + prompt_adapter_requests = metadata_dict.pop( + "prompt_adapter_requests") + if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -767,7 +757,8 @@ def prepare_input_tensors( return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_kwargs, prompt_adapter_requests, prompt_adapter_mapping) + multi_modal_kwargs, prompt_adapter_requests, + prompt_adapter_mapping) @torch.inference_mode() def execute_model( @@ -784,7 +775,8 @@ def execute_model( self.set_active_loras(lora_requests, lora_mapping) if self.prompt_adapter_config: - self.set_active_prompt_adapters(prompt_adapter_requests, prompt_adapter_mapping) + self.set_active_prompt_adapters(prompt_adapter_requests, + prompt_adapter_mapping) # Currently cuda graph is only supported by the decode phase. 
prefill_meta = attn_metadata.prefill_metadata @@ -924,11 +916,10 @@ def remove_all_prompt_adapters(self): if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") self.prompt_adapter_manager.remove_all_prompt_adapters() - + def set_active_prompt_adapters( self, prompt_adapter_requests: Set[PromptAdapterRequest], - prompt_adapter_mapping: PromptAdapterMapping - ) -> None: + prompt_adapter_mapping: PromptAdapterMapping) -> None: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") self.prompt_adapter_manager.set_active_prompt_adapters( @@ -940,7 +931,7 @@ def add_prompt_adapter( raise RuntimeError("PromptAdapter is not enabled.") return self.prompt_adapter_manager.add_prompt_adapter( prompt_adapter_request) - + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5325ff54d64f8..91277ad1891df 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,13 +7,15 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -86,8 +88,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, - prompt_adapter_config = prompt_adapter_config - ) + prompt_adapter_config=prompt_adapter_config) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
self.cache_engine: CacheEngine @@ -332,14 +333,13 @@ def list_loras(self) -> Set[int]: def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.prompt_adapter_manager.add_prompt_adapter( - prompt_adapter_request) - + return self.model_runner.add_prompt_adapter(prompt_adapter_request) + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.prompt_adapter_manager.remove_lora(prompt_adapter_id) + return self.model_runner.remove_lora(prompt_adapter_id) def list_prompt_adapters(self) -> Set[int]: - return self.prompt_adapter_manager.list_prompt_adapters() + return self.model_runner.list_prompt_adapters() @property def max_model_len(self) -> int: From 313127390b66a917773862480f8d686e65ee9865 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 08:07:53 -0400 Subject: [PATCH 03/80] Multimodal fix --- vllm/sequence.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6a761a4b47403..87e4cb8b60358 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -254,7 +254,6 @@ def __init__( @property def prompt(self) -> Optional[str]: return self.inputs.get("prompt") - return self.inputs.get("prompt") @property def prompt_token_ids(self) -> List[int]: @@ -658,7 +657,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional[MultiModalData] = None, + multi_modal_data: Optional["MultiModalData"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, From e9ff38bf662d1b63cf782ffa8c6054dc4522ab24 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 08:29:25 -0400 Subject: [PATCH 04/80] correctness update --- vllm/adapter_commons/worker_manager.py | 5 ++++- vllm/prompt_adapter/layers.py | 2 +- vllm/prompt_adapter/models.py | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index e6a4173efba4e..5a959b9b58333 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -10,8 +10,11 @@ class AbstractWorkerManager(ABC): def __init__(self, device: torch.device): self.device = device - self._model_manager: AdapterModelManager = None + @abstractproperty + def _model_manager(self): + ... + @abstractproperty def is_enabled(self) -> bool: ... 
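A minimal sketch of the mapping object that the reworked prepare_input_tensors hands to set_active_prompt_adapters, assuming a toy batch of two prefill sequences (query lengths 3 and 2) where only the first uses an adapter. The concrete values are invented; only the granularity is taken from the PromptAdapterMapping dataclass, which carries one index_mapping entry per input token and one prompt_mapping entry per sampled token.

    from vllm.prompt_adapter.layers import PromptAdapterMapping

    # Invented values for the toy batch described above: adapter id 1 covers
    # the three tokens of the first sequence, id 0 marks the adapter-less rest.
    mapping = PromptAdapterMapping(
        index_mapping=(1, 1, 1, 0, 0),  # per input token
        prompt_mapping=(1, 0),          # per sampled token
    )

    # On the driver worker, execute_model forwards this mapping together with
    # the collected PromptAdapterRequest set via set_active_prompt_adapters().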
diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index b06b591b061b5..42dc770df0857 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -18,7 +18,7 @@ def __post_init__(self): def apply_prompt_adapter(instance, hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - if hasattr(instance, 'prefix_encoder'): + if instance.prefix_encoder != None: soft_prompt = instance.prefix_encoder.prompt_embedding indices = (positions < soft_prompt.shape[0]) hidden_states[indices] = soft_prompt[positions[indices]] diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index ac6554a53836b..8b07a286bbfd5 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -194,7 +194,8 @@ def _deactivate_prompt_adapter(self, prompt_adapter_id: int): self.prompt_adapter_index_to_id[index] = None for module_name, module in self.model.named_modules(): if 'Model' in (module.__class__.__name__): - del module.prefix_encoder + module.prefix_encoder = None + break except ValueError: pass From 9f0a8ae92658ff6d9cb4a8846790c79d9a13d5e2 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 08:31:24 -0400 Subject: [PATCH 05/80] formatting --- vllm/adapter_commons/worker_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index 5a959b9b58333..80584620343c4 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -3,8 +3,6 @@ import torch -from vllm.adapter_commons.models import AdapterModelManager - class AbstractWorkerManager(ABC): @@ -14,7 +12,7 @@ def __init__(self, device: torch.device): @abstractproperty def _model_manager(self): ... - + @abstractproperty def is_enabled(self) -> bool: ... 
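A minimal sketch of the contract these adapter_commons/worker_manager.py changes establish, using a hypothetical subclass (the class and its backing attribute are not part of the patch): every concrete manager now has to override both abstract properties before it can be instantiated.

    import torch

    from vllm.adapter_commons.worker_manager import AbstractWorkerManager


    class ToyWorkerManager(AbstractWorkerManager):
        """Hypothetical subclass used only to illustrate the abstract API."""

        def __init__(self, device: torch.device):
            super().__init__(device)
            self._manager = None  # backing field, name chosen for this sketch

        @property
        def _model_manager(self):
            return self._manager

        @property
        def is_enabled(self) -> bool:
            return self._manager is not None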
From c2937d1fe87991d01be3bd63dc35464cbf97b74d Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 08:35:57 -0400 Subject: [PATCH 06/80] formatting --- vllm/prompt_adapter/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 42dc770df0857..14e6e25cab893 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -18,7 +18,7 @@ def __post_init__(self): def apply_prompt_adapter(instance, hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - if instance.prefix_encoder != None: + if instance.prefix_encoder is not None: soft_prompt = instance.prefix_encoder.prompt_embedding indices = (positions < soft_prompt.shape[0]) hidden_states[indices] = soft_prompt[positions[indices]] From e43e89bcbf2d77a789488478c679895f8ef0a52a Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 09:50:38 -0400 Subject: [PATCH 07/80] reverting to hasattr --- vllm/prompt_adapter/layers.py | 2 +- vllm/prompt_adapter/models.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 14e6e25cab893..b06b591b061b5 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -18,7 +18,7 @@ def __post_init__(self): def apply_prompt_adapter(instance, hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - if instance.prefix_encoder is not None: + if hasattr(instance, 'prefix_encoder'): soft_prompt = instance.prefix_encoder.prompt_embedding indices = (positions < soft_prompt.shape[0]) hidden_states[indices] = soft_prompt[positions[indices]] diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 8b07a286bbfd5..ac6554a53836b 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -194,8 +194,7 @@ def _deactivate_prompt_adapter(self, prompt_adapter_id: int): self.prompt_adapter_index_to_id[index] = None for module_name, module in self.model.named_modules(): if 'Model' in (module.__class__.__name__): - module.prefix_encoder = None - break + del module.prefix_encoder except ValueError: pass From a2b4fc35b8412e6429f1f0140507b06dde2fa94a Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 10:36:43 -0400 Subject: [PATCH 08/80] adapter commons fix --- vllm/adapter_commons/request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py index 7e50ae184ee22..cf59e29b3c7cd 100644 --- a/vllm/adapter_commons/request.py +++ b/vllm/adapter_commons/request.py @@ -6,7 +6,6 @@ class AdapterRequest: """ Base class for adapter requests. 
""" - adapter_id: int def __post_init__(self): if self.adapter_id < 1: From 3ebee192854fe1b6b587086e5f51f4fe02e039ca Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 12:22:51 -0400 Subject: [PATCH 09/80] minor fixes --- tests/lora/test_long_context.py | 3 ++- vllm/engine/arg_utils.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index b58145eda2141..7174e2c151517 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -8,6 +8,7 @@ from vllm import SamplingParams from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.model_executor.layers.rotary_embedding import ( LinearScalingRotaryEmbedding) @@ -32,7 +33,6 @@ def _create_lora_request(lora_id, long_context_infos): long_context_infos[lora_id]["lora"], 4096 * scaling_factor) - def evaluate_json_response(model_response, golden_response): """Evaluates the model response against the golden response. @@ -96,6 +96,7 @@ def batched_generate( prompt, sampling_param, lora_request=lora_req, + prompt_adapter_request=None ) outputs = llm._run_engine(use_tqdm=True) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d7b104891c395..59b9d125133c8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -65,7 +65,6 @@ class EngineArgs: tokenizer_pool_type: str = "ray" tokenizer_pool_extra_config: Optional[dict] = None enable_lora: bool = False - enable_prompt_adapter: bool = False max_loras: int = 1 max_lora_rank: int = 16 max_prompt_adapters: int = 1 From 629a6845412c9c4526e9b440f22e61a3da286f09 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 12:39:35 -0400 Subject: [PATCH 10/80] formatting --- tests/lora/test_long_context.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 7174e2c151517..50af9b3420833 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -8,7 +8,6 @@ from vllm import SamplingParams from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.model_executor.layers.rotary_embedding import ( LinearScalingRotaryEmbedding) @@ -33,6 +32,7 @@ def _create_lora_request(lora_id, long_context_infos): long_context_infos[lora_id]["lora"], 4096 * scaling_factor) + def evaluate_json_response(model_response, golden_response): """Evaluates the model response against the golden response. 
@@ -92,12 +92,10 @@ def batched_generate( for input in inputs: prompt, sampling_param, lora_req = input # Add requests to the engine and run the engine - llm._validate_and_add_requests( - prompt, - sampling_param, - lora_request=lora_req, - prompt_adapter_request=None - ) + llm._validate_and_add_requests(prompt, + sampling_param, + lora_request=lora_req, + prompt_adapter_request=None) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] From a3ad6ac6ecde723833c9240dad5d865cf0c2fa4a Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 18:14:20 -0400 Subject: [PATCH 11/80] reset_adapter --- vllm/lora/models.py | 4 ++-- vllm/prompt_adapter/models.py | 8 ++++++++ vllm/prompt_adapter/worker_manager.py | 3 +++ vllm/worker/model_runner.py | 12 ++++++++++-- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 2c80aacc60c09..e5833c2c64a17 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -747,7 +747,7 @@ def list_adapters(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" return dict(self._registered_adapters.cache) - def add_adapter(self, lora: LoRAModel) -> bool: + def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" logger.debug( "Adding lora. Model id: %d, " @@ -762,7 +762,7 @@ def add_adapter(self, lora: LoRAModel) -> bool: was_added = False return was_added - def activate_adapter( + def activate_lora( self, lora_id: int, ) -> bool: diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index ac6554a53836b..619b5953685f6 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -157,6 +157,14 @@ def adapter_slots(self) -> int: def capacity(self) -> int: return self.prompt_adapter_config.max_cpu_prompt_adapters + def reset_adapter(self): + try: + for module_name, module in self.model.named_modules(): + if 'Model' in (module.__class__.__name__): + del module.prefix_encoder + except Exception: + pass + def activate_prompt_adapter( self, prompt_adapter_id: int, diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index a79314d481d60..63e94ed2d8c63 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -44,6 +44,9 @@ def __init__( def is_enabled(self) -> bool: return True + def reset_adapter(self): + self._adapter_manager.reset_adapter() + def create_prompt_adapter_manager( self, model: torch.nn.Module, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a5322ea86134d..19173309a6267 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -775,8 +775,11 @@ def execute_model( self.set_active_loras(lora_requests, lora_mapping) if self.prompt_adapter_config: - self.set_active_prompt_adapters(prompt_adapter_requests, - prompt_adapter_mapping) + if len(prompt_adapter_requests) >= 1: + self.set_active_prompt_adapters(prompt_adapter_requests, + prompt_adapter_mapping) + else: + self.reset_adapter() # Currently cuda graph is only supported by the decode phase. 
prefill_meta = attn_metadata.prefill_metadata @@ -912,6 +915,11 @@ def list_loras(self) -> Set[int]: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_loras() + def reset_adapter(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.reset_adapter() + def remove_all_prompt_adapters(self): if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") From dcd7e88d0ca312f4b6be58a7234ed6ad948feec5 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 3 Jun 2024 18:26:11 -0400 Subject: [PATCH 12/80] bugfix --- vllm/prompt_adapter/worker_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index 63e94ed2d8c63..16d13356bd844 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -45,7 +45,7 @@ def is_enabled(self) -> bool: return True def reset_adapter(self): - self._adapter_manager.reset_adapter() + self._prompt_adapter_manager.reset_adapter() def create_prompt_adapter_manager( self, From 647a32dcda673ced41ff0964275ba74de0df1114 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 4 Jun 2024 10:08:13 -0400 Subject: [PATCH 13/80] reset_adapter fix --- vllm/prompt_adapter/models.py | 1 + vllm/prompt_adapter/worker_manager.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 619b5953685f6..7676bfd259ed9 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -159,6 +159,7 @@ def capacity(self) -> int: def reset_adapter(self): try: + self.remove_all_prompt_adapters() for module_name, module in self.model.named_modules(): if 'Model' in (module.__class__.__name__): del module.prefix_encoder diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index 16d13356bd844..ea141fd31593c 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -29,7 +29,7 @@ def __init__( max_num_seqs: int, max_num_batched_tokens: int, device: torch.device, - prompt_adapter_config: Type[PromptAdapterConfig], + prompt_adapter_config: PromptAdapterConfig, prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel ): self._prompt_adapter_manager: Optional[ From 90d170c1170bfa4d35b70b8a08fa1278019039e1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 4 Jun 2024 10:12:48 -0400 Subject: [PATCH 14/80] peft dependencies --- requirements-common.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-common.txt b/requirements-common.txt index bf9987e3af014..ca18843e2bde8 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,3 +20,4 @@ lm-format-enforcer == 0.10.1 outlines >= 0.0.43 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +peft From 0fca8958a4e0c43cd486df5aa77fc4be44d63907 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 4 Jun 2024 10:23:37 -0400 Subject: [PATCH 15/80] fixing llava bug --- vllm/worker/embedding_model_runner.py | 8 ++++---- vllm/worker/model_runner.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 37f94dc28a5b0..6743bb3a6e9bc 100644 --- a/vllm/worker/embedding_model_runner.py +++ 
b/vllm/worker/embedding_model_runner.py @@ -105,7 +105,7 @@ def prepare_input_tensors( assert seq_group_metadata_list is not None # Prepare input tensors. (input_tokens, input_positions, attn_metadata, seq_lens, _, - lora_mapping, lora_requests, multi_modal_input, slot_mapping, + lora_mapping, lora_requests, multi_modal_kwargs, slot_mapping, num_prefill_tokens, num_decode_tokens, num_prefills, prompt_adapter_mapping, prompt_adapter_requests ) = self._prepare_model_input(seq_group_metadata_list) @@ -118,7 +118,7 @@ def prepare_input_tensors( "input_positions": input_positions, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, + "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, @@ -135,7 +135,7 @@ def prepare_input_tensors( input_positions = metadata_dict.pop("input_positions") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_input = metadata_dict.pop("multi_modal_input") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") prompt_adapter_mapping = metadata_dict.pop( "prompt_adapter_mapping") prompt_adapter_requests = metadata_dict.pop( @@ -150,7 +150,7 @@ def prepare_input_tensors( prompt_lens=None) return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input, + lora_requests, lora_mapping, multi_modal_kwargs, prompt_adapter_requests, prompt_adapter_mapping) def _prepare_pooling( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 19173309a6267..90393bbfbed2d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -466,6 +466,17 @@ def _prepare_model_input( and seq_group_metadata.sampling_params.prompt_logprobs is not None else 1)) + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) + if prompt_adapter_id > 0: prompt_adapter_requests.add( seq_group_metadata.prompt_adapter_request) From d4e531c250127e84568b889e1261d3b4a546303c Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 4 Jun 2024 10:30:03 -0400 Subject: [PATCH 16/80] typing fix --- vllm/worker/embedding_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 6743bb3a6e9bc..204b110bbcfbe 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -99,7 +99,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor, + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor], Set[PromptAdapterRequest], PromptAdapterMapping]: if self.is_driver_worker: assert seq_group_metadata_list is not None From b7f82568b9c526edb1d865205390807b2c303c11 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 4 Jun 2024 17:03:38 -0400 Subject: [PATCH 17/80] async engine update --- vllm/engine/async_llm_engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 171d4670087c3..60ccff09abe5d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -611,6 +611,7 @@ async def generate( sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -678,6 +679,7 @@ async def generate( inputs, sampling_params, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ): yield LLMEngine.validate_output(output, RequestOutput) @@ -762,6 +764,7 @@ async def _process_request( params: Union[SamplingParams, PoolingParams], *, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -773,6 +776,7 @@ async def _process_request( params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ) try: From 449d988ad220b1abfd52334716e445d7106522a5 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 05:13:39 -0400 Subject: [PATCH 18/80] batchwise processing --- vllm/adapter_commons/layers.py | 13 +++ vllm/lora/layers.py | 12 +- vllm/model_executor/models/baichuan.py | 2 - vllm/model_executor/models/bloom.py | 2 - vllm/model_executor/models/gpt_bigcode.py | 2 - vllm/model_executor/models/llama.py | 2 - vllm/model_executor/models/mixtral.py | 2 - vllm/prompt_adapter/layers.py | 69 ++++++++--- vllm/prompt_adapter/models.py | 136 ++++++---------------- vllm/prompt_adapter/worker_manager.py | 3 - vllm/worker/model_runner.py | 20 ++-- 11 files changed, 108 insertions(+), 155 deletions(-) create mode 100644 vllm/adapter_commons/layers.py diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 0000000000000..251ef7eb2ddee --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Tuple + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e3ab1708c3fdf..41e053434db7b 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.adapter_commons.layers import AdapterMapping from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -134,15 +135,8 @@ def _apply_lora_packed_nslice( @dataclass -class LoRAMapping: - # Per every token in input_ids: - index_mapping: Tuple[int, ...] - # Per sampled token: - prompt_mapping: Tuple[int, ...] 
- - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) +class LoRAMapping(AdapterMapping): + pass class BaseLayerWithLoRA(nn.Module): diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 7608c4caeff40..babb92e7cdcef 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -43,7 +43,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput @@ -279,7 +278,6 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) - hidden_states = apply_prompt_adapter(self, hidden_states, positions) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 78c0d7756a610..a29aee4cffb7d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -39,7 +39,6 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput @@ -252,7 +251,6 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) - hidden_states = apply_prompt_adapter(self, hidden_states, position_ids) hidden_states = self.word_embeddings_layernorm(hidden_states) for i in range(len(self.h)): layer = self.h[i] diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index dd1a7750a49ee..69b75763e9a3d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,7 +39,6 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput @@ -220,7 +219,6 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) - inputs_embeds = apply_prompt_adapter(self, inputs_embeds, position_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f9088449e007c..d83ee9a201c0b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -46,7 +46,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once @@ -283,7 +282,6 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) - hidden_states = apply_prompt_adapter(self, hidden_states, positions) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/model_executor/models/mixtral.py 
b/vllm/model_executor/models/mixtral.py index 8948e1cccb4ac..3faf54d292b99 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -51,7 +51,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.prompt_adapter.layers import apply_prompt_adapter from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once @@ -463,7 +462,6 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) - hidden_states = apply_prompt_adapter(self, hidden_states, positions) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index b06b591b061b5..185a4bdfee22a 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -1,25 +1,56 @@ from dataclasses import dataclass -from typing import Tuple +from typing import Dict, List, Optional +import numpy import torch +from torch import nn + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) @dataclass -class PromptAdapterMapping: - # Per every token in input_ids: - index_mapping: Tuple[int, ...] - # Per sampled token: - prompt_mapping: Tuple[int, ...] - - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) - - -def apply_prompt_adapter(instance, hidden_states: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: - if hasattr(instance, 'prefix_encoder'): - soft_prompt = instance.prefix_encoder.prompt_embedding - indices = (positions < soft_prompt.shape[0]) - hidden_states[indices] = soft_prompt[positions[indices]] - return hidden_states +class PromptAdapterMapping(AdapterMapping): + pass + + +class VocabParallelEmbeddingWithPromptAdapter(nn.Module): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embedding_tensors: Dict[int, torch.Tensor] = {} + self.indices: torch.Tensor + + def reset_prompt_adapter(self, index: int): + self.embedding_tensors[index] = 0 + + def set_prompt_adapter( + self, + index: int, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_prompt_adapter(index) + if embeddings_tensor is not None: + self.embedding_tensors[index] = embeddings_tensor + + def set_mapping( + self, + base_indices: List[int], + ): + self.indices = base_indices + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = self.base_layer(x) + unique_indices = numpy.unique(self.indices) + for idx in unique_indices: + if idx != 0: + pa_idx = self.embedding_tensors[idx].prompt_embedding + mask = (self.indices == idx) + try: + n_adapters = sum(mask) // pa_idx.shape[0] + hidden_states[mask] = pa_idx.repeat(n_adapters, 1) + except Exception: + pass + return hidden_states diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 7676bfd259ed9..9b029bc697ee2 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -1,15 +1,15 @@ import logging import math -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Type -import torch from peft.utils import load_peft_weights from torch import nn from 
vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.layers import ( + PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) logger = logging.getLogger(__name__) @@ -22,69 +22,6 @@ def get_prompt_adapter_id(): return _GLOBAL_PROMPT_ADAPTER_ID -def convert_mapping( - mapping: PromptAdapterMapping, - prompt_adapter_index_to_id: List[Optional[int]], max_prompt_adapters: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: - """Converts PromptAdapterMapping to index tensors. - - Args: - mapping: PromptAdapterMapping mapping rows in a batch to ids. - prompt_adapter_index_to_id: List mapping PromptAdapter ids to indices. - max_prompt_adapters: Maximum number of PromptAdapters. - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - PromptAdapter indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - PromptAdapter indices for sampler. For generation, this will be - same as base_indicies. For prefill, this will map requests - to PromptAdapter indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to PromptAdapter indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_promt_adapters. - indices_len: List of lengths of the above tensors. - Used to index into each tensor. It contains length for - (base_indices, sampler_indices, sampler_indices_padded). - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - prompt_adapter_indices = index_mapping_indices.copy() - prompt_mapping: List[int] = [ - prompt_adapter_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - prompt_adapter_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - prompt_adapter_idx = (prompt_adapter_index_to_id.index( - index_mapping_indices[i]) if index_mapping_indices[i] > 0 else -1) - prompt_adapter_indices[i] = prompt_adapter_idx - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, prompt_adapter_indices - ] - indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") - prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == - -1] = max_prompt_adapters - 1 - sampler_indices_padded = ( - torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + - (sampler_indices_padded * len(sampler_indices_padded))) - # Contain length of indices tensors. Used to index into each tensor. 
- indices_len = [ - base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1] - ] - return (base_indices, sampler_indices, sampler_indices_padded, indices_len) - - class PromptAdapterModel(AdapterModel): def __init__(self, @@ -133,16 +70,9 @@ def __init__( self.model.prompt_adapter_manager = self self.adapter_type = 'PromptAdapter' - self.base_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.sampler_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.indices_len: List[Optional[int]] = [None] * 3 + self.base_indices = [0] + self.modules: Dict[str, nn.Module] = {} + self._create_prompt_adapter_modules() self._last_mapping: Optional[PromptAdapterMapping] = None @property @@ -157,15 +87,6 @@ def adapter_slots(self) -> int: def capacity(self) -> int: return self.prompt_adapter_config.max_cpu_prompt_adapters - def reset_adapter(self): - try: - self.remove_all_prompt_adapters() - for module_name, module in self.model.named_modules(): - if 'Model' in (module.__class__.__name__): - del module.prefix_encoder - except Exception: - pass - def activate_prompt_adapter( self, prompt_adapter_id: int, @@ -187,10 +108,8 @@ def activate_prompt_adapter( logger.debug("Activating prompt_adapter. int id: %d, slot index: %d", prompt_adapter_model.id, index) self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id - for module_name, module in self.model.named_modules(): - if 'Model' in (module.__class__.__name__): - module.prefix_encoder = prompt_adapter_model - break + for _, v in self.modules.items(): + v.set_prompt_adapter(prompt_adapter_id, prompt_adapter_model) return True @property @@ -201,9 +120,8 @@ def _deactivate_prompt_adapter(self, prompt_adapter_id: int): try: index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) self.prompt_adapter_index_to_id[index] = None - for module_name, module in self.model.named_modules(): - if 'Model' in (module.__class__.__name__): - del module.prefix_encoder + for _, v in self.modules.items(): + v.reset_prompt_adapter(prompt_adapter_id) except ValueError: pass @@ -232,16 +150,30 @@ def remove_prompt_adapter(self): def _set_prompt_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - (base_indices, sampler_indices, sampler_indices_padded, - indices_len) = convert_mapping(mapping, - self.prompt_adapter_index_to_id, - self.prompt_adapter_slots + 1) - self.base_indices[:base_indices.shape[0]].copy_(base_indices) - self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - # Maintain the reference - self.indices_len[:] = indices_len + for k, v in self.modules.items(): + v.set_mapping(mapping.index_mapping) + + def _create_prompt_adapter_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if "VocabParallel" in module.__class__.__name__: + new_module = VocabParallelEmbeddingWithPromptAdapter(module) + replaced_module = self.replace_submodule( + self.model, module_name, new_module) + self.register_module(module.__class__.__name__, + replaced_module) + replaced_module.set_mapping(self.base_indices) + + def replace_submodule(self, model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = 
model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + def register_module(self, module_name: str, module: nn.Module): + self.modules[module_name] = module @property def set_prompt_adapter_mapping(self): diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index ea141fd31593c..77e5aefab20cf 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -44,9 +44,6 @@ def __init__( def is_enabled(self) -> bool: return True - def reset_adapter(self): - self._prompt_adapter_manager.reset_adapter() - def create_prompt_adapter_manager( self, model: torch.nn.Module, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 90393bbfbed2d..bc8a86065cce9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -476,12 +476,16 @@ def _prepare_model_input( mm_kwargs = self.multi_modal_input_processor(mm_data) for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) - + if prompt_adapter_id > 0: prompt_adapter_requests.add( seq_group_metadata.prompt_adapter_request) - prompt_adapter_index_mapping += [prompt_adapter_id] * query_len + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm prompt_adapter_prompt_mapping.extend( [prompt_adapter_id] * (query_len if seq_group_metadata.sampling_params @@ -786,11 +790,8 @@ def execute_model( self.set_active_loras(lora_requests, lora_mapping) if self.prompt_adapter_config: - if len(prompt_adapter_requests) >= 1: - self.set_active_prompt_adapters(prompt_adapter_requests, - prompt_adapter_mapping) - else: - self.reset_adapter() + self.set_active_prompt_adapters(prompt_adapter_requests, + prompt_adapter_mapping) # Currently cuda graph is only supported by the decode phase. 
prefill_meta = attn_metadata.prefill_metadata @@ -926,11 +927,6 @@ def list_loras(self) -> Set[int]: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_loras() - def reset_adapter(self): - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.reset_adapter() - def remove_all_prompt_adapters(self): if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") From f28b66e2f781227915a66c562a6e21ed2fa00daa Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 05:15:42 -0400 Subject: [PATCH 19/80] formatting --- vllm/adapter_commons/layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py index 251ef7eb2ddee..3ed60678b52f5 100644 --- a/vllm/adapter_commons/layers.py +++ b/vllm/adapter_commons/layers.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Tuple + @dataclass class AdapterMapping: # Per every token in input_ids: From 220deef3b5336e072baae65bf8f20a7af7762e66 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 05:19:27 -0400 Subject: [PATCH 20/80] formatting yapf --- vllm/prompt_adapter/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 9b029bc697ee2..7b626e6396718 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -8,8 +8,8 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import ( - PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) +from vllm.prompt_adapter.layers import (PromptAdapterMapping, + VocabParallelEmbeddingWithPromptAdapter) logger = logging.getLogger(__name__) From 01b9bb878c20b9b4038efc4c76ad4987935335b3 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 05:21:50 -0400 Subject: [PATCH 21/80] formatting again --- vllm/prompt_adapter/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 7b626e6396718..9b029bc697ee2 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -8,8 +8,8 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import (PromptAdapterMapping, - VocabParallelEmbeddingWithPromptAdapter) +from vllm.prompt_adapter.layers import ( + PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) logger = logging.getLogger(__name__) From 2ea2796eafc77f900647e96c1a2438338070d5f1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 09:47:05 -0400 Subject: [PATCH 22/80] enable_adapter paramter --- tests/prompt_adapter/test_bloom.py | 2 +- vllm/engine/arg_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index 5f09ee1304f76..85ef7ebbf8145 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -30,7 +30,7 @@ def do_sample(llm, pa_name: str, pa_id: int): def test_twitter_prompt_adapter(): - llm = vllm.LLM(MODEL_PATH) + llm = vllm.LLM(MODEL_PATH, enable_prompt_adapter=True) expected_output = ['complaint', 'no complaint'] diff --git a/vllm/engine/arg_utils.py 
b/vllm/engine/arg_utils.py index 59b9d125133c8..fcdda6da0b23b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -67,6 +67,7 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + enable_prompt_adapter: bool = False max_prompt_adapters: int = 1 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 @@ -503,6 +504,9 @@ def add_cli_args( 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') parser.add_argument('--max-prompt-adapters', type=int, default=EngineArgs.max_prompt_adapters, @@ -729,7 +733,7 @@ def create_engine_config(self, ) -> EngineConfig: ) prompt_adapter_config = PromptAdapterConfig( - max_prompt_adapters=self.max_prompt_adapters) + max_prompt_adapters=self.max_prompt_adapters) if self.enable_prompt_adapter else None if self.image_input_type: if (not self.image_token_id or not self.image_input_shape From 96fe5aea6058ae9b3da1275805cd2606863ef028 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 09:50:52 -0400 Subject: [PATCH 23/80] formatting --- vllm/engine/arg_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fcdda6da0b23b..fc05b80c8aad5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -733,7 +733,8 @@ def create_engine_config(self, ) -> EngineConfig: ) prompt_adapter_config = PromptAdapterConfig( - max_prompt_adapters=self.max_prompt_adapters) if self.enable_prompt_adapter else None + max_prompt_adapters=self.max_prompt_adapters) \ + if self.enable_prompt_adapter else None if self.image_input_type: if (not self.image_token_id or not self.image_input_shape From 47725d9087842c63637ce49c2547c9e3d7c78e7f Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 12:36:12 -0400 Subject: [PATCH 24/80] adding test --- .../test_multi_adapter_inference.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/prompt_adapter/test_multi_adapter_inference.py diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py new file mode 100644 index 0000000000000..d2719c9eeb7de --- /dev/null +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -0,0 +1,57 @@ +import vllm +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' +pa_path2 = 'swapnilbp/angry_tweet_ptune' + + +def do_sample(engine): + + prompts = [ + ( + "Tweet text: I have complaints! Label: " , + SamplingParams(temperature=0.0, max_tokens=3), + PromptAdapterRequest("hate_speech", 1, pa_path2 ,8)), + ( + "Tweet text: I have no problems Label: " , + SamplingParams(temperature=0.0, max_tokens=3), + PromptAdapterRequest("hate_speech2", 2, pa_path2,8)), + ( + "Tweet text: I have complaints! 
Label: " , + SamplingParams(temperature=0.0, max_tokens=3), + None), + ( + "Tweet text: I have no problems Label: " , + SamplingParams(temperature=0.0, max_tokens=3), + PromptAdapterRequest("complain", 3, pa_path,8)), + + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_twitter_prompt_adapter(): + engine_args = EngineArgs(model="bigscience/bloomz-560m", + max_prompt_adapters=3, + enable_prompt_adapter=True) + engine = LLMEngine.from_engine_args(engine_args) + expected_output = {' quot;I', 'hate speech', 'no complaint', 'not hate speech'} + assert do_sample(engine) == expected_output \ No newline at end of file From 638795a57dd0ee2f84d71b080e73b78d7bda42c8 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 12:36:38 -0400 Subject: [PATCH 25/80] adding test --- .../test_multi_adapter_inference.py | 45 +++++++++---------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py index d2719c9eeb7de..94139cdf6363f 100644 --- a/tests/prompt_adapter/test_multi_adapter_inference.py +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -1,5 +1,4 @@ -import vllm -from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.prompt_adapter.request import PromptAdapterRequest MODEL_PATH = "bigscience/bloomz-560m" @@ -10,23 +9,17 @@ def do_sample(engine): prompts = [ - ( - "Tweet text: I have complaints! Label: " , - SamplingParams(temperature=0.0, max_tokens=3), - PromptAdapterRequest("hate_speech", 1, pa_path2 ,8)), - ( - "Tweet text: I have no problems Label: " , - SamplingParams(temperature=0.0, max_tokens=3), - PromptAdapterRequest("hate_speech2", 2, pa_path2,8)), - ( - "Tweet text: I have complaints! Label: " , - SamplingParams(temperature=0.0, max_tokens=3), - None), - ( - "Tweet text: I have no problems Label: " , - SamplingParams(temperature=0.0, max_tokens=3), - PromptAdapterRequest("complain", 3, pa_path,8)), - + ("Tweet text: I have complaints! Label: ", + SamplingParams(temperature=0.0, max_tokens=3), + PromptAdapterRequest("hate_speech", 1, pa_path2, 8)), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3), + PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)), + ("Tweet text: I have complaints! 
Label: ", + SamplingParams(temperature=0.0, max_tokens=3), None), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3), + PromptAdapterRequest("complain", 3, pa_path, 8)), ] request_id = 0 @@ -39,9 +32,9 @@ def do_sample(engine): sampling_params, prompt_adapter_request=pa_request) request_id += 1 - + request_outputs = engine.step() - + for request_output in request_outputs: if request_output.finished: results.add(request_output.outputs[0].text) @@ -50,8 +43,10 @@ def do_sample(engine): def test_twitter_prompt_adapter(): engine_args = EngineArgs(model="bigscience/bloomz-560m", - max_prompt_adapters=3, - enable_prompt_adapter=True) + max_prompt_adapters=3, + enable_prompt_adapter=True) engine = LLMEngine.from_engine_args(engine_args) - expected_output = {' quot;I', 'hate speech', 'no complaint', 'not hate speech'} - assert do_sample(engine) == expected_output \ No newline at end of file + expected_output = { + ' quot;I', 'hate speech', 'no complaint', 'not hate speech' + } + assert do_sample(engine) == expected_output From f7d53b34052bc7809e4481375ba3995256b888d5 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 19:35:41 -0400 Subject: [PATCH 26/80] test case update --- tests/prompt_adapter/test_bloom.py | 2 +- tests/prompt_adapter/test_multi_adapter_inference.py | 6 +++--- tests/spec_decode/e2e/conftest.py | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index 85ef7ebbf8145..fd2a7a0ac168d 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -12,7 +12,7 @@ def do_sample(llm, pa_name: str, pa_id: int): current and paid. Can you do something about this? Label : ", "Tweet text : @nationalgridus Looks good thanks! Label : " ] - sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=3) + sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]) outputs = llm.generate(prompts, sampling_params, diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py index 94139cdf6363f..0317c49d7380a 100644 --- a/tests/prompt_adapter/test_multi_adapter_inference.py +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -10,15 +10,15 @@ def do_sample(engine): prompts = [ ("Tweet text: I have complaints! Label: ", - SamplingParams(temperature=0.0, max_tokens=3), + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), PromptAdapterRequest("hate_speech", 1, pa_path2, 8)), ("Tweet text: I have no problems Label: ", - SamplingParams(temperature=0.0, max_tokens=3), + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)), ("Tweet text: I have complaints! 
Label: ", SamplingParams(temperature=0.0, max_tokens=3), None), ("Tweet text: I have no problems Label: ", - SamplingParams(temperature=0.0, max_tokens=3), + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), PromptAdapterRequest("complain", 3, pa_path, 8)), ] diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index f8a6de54653c1..163465dc1cfe3 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -17,6 +17,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.model_executor.utils import set_random_seed from vllm.multimodal import MultiModalData from vllm.outputs import RequestOutput @@ -98,6 +99,7 @@ def generate( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> List[RequestOutput]: if prompts is None: From 16f4037bc40d0ab3db2056b0b298b92ecb96c0a7 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 5 Jun 2024 19:36:06 -0400 Subject: [PATCH 27/80] formatting --- tests/prompt_adapter/test_bloom.py | 4 +++- tests/spec_decode/e2e/conftest.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index fd2a7a0ac168d..7c13a81b6f2cb 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -12,7 +12,9 @@ def do_sample(llm, pa_name: str, pa_id: int): current and paid. Can you do something about this? Label : ", "Tweet text : @nationalgridus Looks good thanks! 
Label : " ] - sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]) + sampling_params = vllm.SamplingParams(temperature=0.0, + max_tokens=3, + stop_token_ids=[3]) outputs = llm.generate(prompts, sampling_params, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 163465dc1cfe3..820d9af235051 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -17,10 +17,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.model_executor.utils import set_random_seed from vllm.multimodal import MultiModalData from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.usage.usage_lib import UsageContext From f2f3cbc54ab1c1047197f4957e98fa347a7fa854 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Thu, 13 Jun 2024 13:37:09 -0400 Subject: [PATCH 28/80] resetting --- vllm/worker/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bc8a86065cce9..b423b3d199ee6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -442,6 +442,7 @@ def _prepare_model_input( input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id prompt_adapter_id = seq_group_metadata.prompt_adapter_id + prompt_adapter_id = seq_group_metadata.prompt_adapter_id if is_prompt: assert len(seq_ids) == 1 From 0fc0c34971e980203b3e7b627f7eccc43e1d678c Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Thu, 13 Jun 2024 13:42:43 -0400 Subject: [PATCH 29/80] formatting --- vllm/entrypoints/llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d1762de87bd8d..156fb8b463521 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -532,7 +532,8 @@ def _add_request( self, inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[List[LoRARequest], + LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: request_id = str(next(self.request_counter)) From 4eb47d685196ce9247923e61f8b397109fe26506 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Thu, 13 Jun 2024 13:43:49 -0400 Subject: [PATCH 30/80] formatting --- vllm/worker/model_runner.py | 3 --- vllm/worker/worker.py | 1 - 2 files changed, 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b423b3d199ee6..712c683042a3e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,10 +25,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, diff 
--git a/vllm/worker/worker.py b/vllm/worker/worker.py index 91277ad1891df..427cd2b34c306 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -17,7 +17,6 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine From e69842bd74eb8bc81b3f96ab7cc83a5b03a4bef8 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Thu, 13 Jun 2024 13:44:05 -0400 Subject: [PATCH 31/80] formatting --- vllm/entrypoints/llm.py | 2 +- vllm/worker/model_runner.py | 2 +- vllm/worker/worker.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 156fb8b463521..8e684a214ab0b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -532,7 +532,7 @@ def _add_request( self, inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 712c683042a3e..c9fbfd4f2d9b2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,12 +20,12 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 427cd2b34c306..1a62fe91d310c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -16,8 +16,8 @@ set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner From 5c174806f785443397b2a76957d2e09de2f9e604 Mon Sep 17 00:00:00 2001 From: Joe G Date: Thu, 13 Jun 2024 13:04:58 -0700 Subject: [PATCH 32/80] Fix async engine --- vllm/engine/async_llm_engine.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 60ccff09abe5d..510c8c1a76021 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -260,6 +260,7 @@ async def process_model_inputs_async( request_id: str, inputs: 
PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -275,6 +276,11 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = [ + 0 + ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + prompt_token_ids + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -295,7 +301,10 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( - request_id=request_id, inputs=inputs, lora_request=lora_request) + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, From e62cbb5692a1c884d36d185054bdee521395b39e Mon Sep 17 00:00:00 2001 From: Joe G Date: Thu, 13 Jun 2024 14:01:07 -0700 Subject: [PATCH 33/80] Initial implementation of openai entrypoint Assumes the interface for prompt adapters and lora modules remains completely separate. --- vllm/entrypoints/openai/api_server.py | 5 +- vllm/entrypoints/openai/cli_args.py | 21 ++++++- vllm/entrypoints/openai/serving_completion.py | 11 +++- vllm/entrypoints/openai/serving_engine.py | 59 +++++++++++++++++-- 4 files changed, 86 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e7503b9655830..ec35a81c0b899 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -87,7 +87,7 @@ async def health() -> Response: @app.get("/v1/models") async def show_available_models(): - models = await openai_serving_chat.show_available_models() + models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) @@ -216,7 +216,8 @@ async def authentication(request: Request, call_next): args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules, + args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4c0cb1e4f3e49..fd142643aec75 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,8 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) class LoRAParserAction(argparse.Action): @@ -22,6 +23,16 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, lora_list) +class PromptAdapterParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + adapter_list = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + def make_arg_parser(): parser = argparse.ArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") @@ -64,6 +75,14 @@ def make_arg_parser(): action=LoRAParserAction, 
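For context, a hypothetical end-to-end use of the prompt adapter support this patch wires into the OpenAI-compatible server: adapters registered via --prompt-adapters are listed as models and selected per request by name. The server command, adapter path, adapter name, and host/port below are illustrative placeholders, not taken from the patch.

# Assumed server launch (placeholders):
#   python -m vllm.entrypoints.openai.api_server \
#       --model bigscience/bloomz-560m \
#       --enable-prompt-adapter \
#       --prompt-adapters twitter_pa=/path/to/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        # Selecting a registered prompt adapter by its name routes the request
        # through _maybe_get_prompt_adapter on the serving side.
        "model": "twitter_pa",
        "prompt": "Tweet text : @nationalgridus Looks good thanks! Label : ",
        "max_tokens": 3,
        "temperature": 0.0,
    },
)
print(response.json()["choices"][0]["text"])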
help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") parser.add_argument("--chat-template", type=nullable_str, default=None, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 64671e21a724d..a91df755f4bc2 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -18,7 +18,8 @@ CompletionStreamResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, - OpenAIServing) + OpenAIServing, + PromptAdapterPath) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -61,11 +62,13 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]]): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + prompt_adapters=prompt_adapters) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -96,6 +99,7 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + prompt_adapter_request = self._maybe_get_prompt_adapter(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend @@ -133,6 +137,7 @@ async def create_completion(self, request: CompletionRequest, sampling_params, f"{request_id}-{i}", lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ) generators.append(generator) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 6b5a62efc7f20..a0cc0d04f44cd 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -15,12 +15,19 @@ ModelPermission) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + @dataclass class LoRAModulePath: name: str @@ -29,9 +36,14 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]] = None, + ): super().__init__() self.engine = engine @@ -58,6 +70,19 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, ) for i, lora in enumerate(lora_modules, start=1) ] + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, 
prompt_adapter in enumerate(prompt_adapters, start=1): + with open(prompt_adapter.local_path) as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -73,6 +98,13 @@ async def show_available_models(self) -> ModelList: permission=[ModelPermission()]) for lora in self.lora_requests ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.served_model_names[0], + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(prompt_adapter_cards) model_cards.extend(lora_cards) return ModelList(data=model_cards) @@ -106,6 +138,11 @@ async def _check_model( return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.prompt_adapter_requests + ]: + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", @@ -120,8 +157,22 @@ def _maybe_get_lora( for lora in self.lora_requests: if request.model == lora.lora_name: return lora + return None + # if _check_model has been called earlier, this will be unreachable + #raise ValueError(f"The model `{request.model}` does not exist.") + + def _maybe_get_prompt_adapter( + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] + ) -> Optional[PromptAdapterRequest]: + if request.model in self.served_model_names: + return None + for prompt_adapter in self.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return prompt_adapter + return None # if _check_model has been called earlier, this will be unreachable - raise ValueError(f"The model `{request.model}` does not exist.") + #raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( self, From 612d6c5ef4fa5490d0292eb4c5a7f8af64d1cb38 Mon Sep 17 00:00:00 2001 From: Joe G Date: Thu, 13 Jun 2024 14:23:34 -0700 Subject: [PATCH 34/80] Fixes --- vllm/engine/async_llm_engine.py | 3 ++- vllm/entrypoints/openai/serving_engine.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 510c8c1a76021..f5bea488a04b8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -279,7 +279,8 @@ async def process_model_inputs_async( if prompt_adapter_request: prompt_token_ids = [ 0 - ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + prompt_token_ids + ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ + prompt_token_ids return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a0cc0d04f44cd..e6d99686e72da 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -73,7 +73,7 @@ def __init__( self.prompt_adapter_requests = [] if prompt_adapters is not None: for i, prompt_adapter in enumerate(prompt_adapters, start=1): - with 
open(prompt_adapter.local_path) as f: + with open(f"./{prompt_adapter.local_path}/adapter_config.json") as f: adapter_config = json.load(f) num_virtual_tokens = adapter_config["num_virtual_tokens"] self.prompt_adapter_requests.append( From 894b9ba4eca99365a1b5dde704e532369c14603d Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 18 Jun 2024 12:13:54 -0400 Subject: [PATCH 35/80] async changes --- format.sh | 14 +++++++------- vllm/engine/async_llm_engine.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/format.sh b/format.sh index 65f01b6eb303c..fd512a34bb22a 100755 --- a/format.sh +++ b/format.sh @@ -26,7 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) -# CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') +CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { @@ -36,12 +36,12 @@ tool_version_check() { fi } -tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-lint.txt | cut -d'=' -f3)" +tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-lint.txt | cut -d'=' -f3)" +tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-lint.txt | cut -d'=' -f3)" +tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-lint.txt | cut -d'=' -f3)" +tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-lint.txt | cut -d'=' -f3)" +tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-lint.txt | cut -d'=' -f3)" YAPF_FLAGS=( '--recursive' diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 60ccff09abe5d..cef10e4142e10 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -260,6 +260,7 @@ async def process_model_inputs_async( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -275,6 +276,12 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = [ + 0 + ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ + prompt_token_ids + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -295,7 +302,8 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( - request_id=request_id, inputs=inputs, lora_request=lora_request) + request_id=request_id, inputs=inputs, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( 
request_id=request_id, @@ -305,6 +313,7 @@ async def add_request_async( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) + async def check_health_async(self) -> None: self.model_executor.check_health() From 155ad76f8730d2e736106a3f30286d0f6d5c80b1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 18 Jun 2024 13:22:21 -0400 Subject: [PATCH 36/80] formattign --- vllm/engine/async_llm_engine.py | 9 +++++---- vllm/engine/llm_engine.py | 16 +++++++--------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9af51bd05be42..5f527956537cc 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -306,7 +306,9 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( - request_id=request_id, inputs=inputs, lora_request=lora_request, + request_id=request_id, + inputs=inputs, + lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) self._add_processed_request( @@ -315,8 +317,8 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request ) async def check_health_async(self) -> None: @@ -616,8 +618,7 @@ async def add_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) return stream diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2af88851db028..281625ab287b0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,8 +8,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig, PromptAdapterConfig) + ParallelConfig, PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -216,7 +216,7 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.prompt_adapter_config = prompt_adapter_config self.observability_config = observability_config or ObservabilityConfig( - + ) self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: @@ -459,8 +459,8 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], - trace_headers: Optional[Dict[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Dict[str, str]] = None, ) -> None: # Create the sequences. 
block_size = self.cache_config.block_size @@ -479,8 +479,7 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -595,8 +594,8 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request ) def _create_sequence_group_with_sampling( @@ -636,8 +635,7 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request - ) + prompt_adapter_request=prompt_adapter_request) return seq_group From 042c9f1861968d1d664bff75983eb036f318874c Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 18 Jun 2024 13:23:22 -0400 Subject: [PATCH 37/80] formatting --- format.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/format.sh b/format.sh index 17dafb4b5deac..5edc868f9f70c 100755 --- a/format.sh +++ b/format.sh @@ -111,7 +111,6 @@ mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml -mypy vllm/model_executor --config-file pyproject.toml mypy vllm/prompt_adapter --config-file pyproject.toml mypy tests --config-file pyproject.toml From 0e46a06fd23198c481d8ee33eca08205f2049dc9 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:10:28 -0400 Subject: [PATCH 38/80] adding dtype flexibility + pa lora refactor --- tests/lora/test_lora_manager.py | 264 +++++++++--------- .../test_multi_adapter_inference.py | 4 +- tests/prompt_adapter/test_pa_lora.py | 54 ++++ vllm/adapter_commons/models.py | 22 +- vllm/adapter_commons/worker_manager.py | 33 +-- vllm/config.py | 11 + vllm/core/scheduler.py | 8 - vllm/engine/llm_engine.py | 3 + vllm/lora/models.py | 71 ++--- vllm/lora/worker_manager.py | 79 ++---- vllm/prompt_adapter/layers.py | 22 +- vllm/prompt_adapter/models.py | 115 +++----- vllm/prompt_adapter/request.py | 4 +- vllm/prompt_adapter/worker_manager.py | 85 ++---- vllm/worker/model_runner.py | 24 +- 15 files changed, 346 insertions(+), 453 deletions(-) create mode 100644 tests/prompt_adapter/test_pa_lora.py diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 51a56b121ae2c..bbf0e0cda0453 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not 
manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 with pytest.raises(ValueError): - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] is None - assert manager.add_lora(model_lora2) - assert manager.activate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] is None - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 @@ -173,40 +173,40 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.add_lora(model_lora2) - assert manager.deactivate_lora(3) + assert manager.add_adapter(model_lora2) + 
assert manager.deactivate_adapter(3) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 @@ -228,132 +228,133 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity - assert manager.add_lora(model_lora1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(1) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(1) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 # Add over capacity - assert manager.add_lora(model_lora3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(3) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(3) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 # Add 3 again to move it to the top and then add 2 # should return false since it's in already - assert not manager.add_lora(model_lora3) - assert not manager.activate_lora(3) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora3) + assert not manager.activate_adapter(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {3, 2} + assert set(manager.list_adapters()) == {3, 2} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {2} + assert set(manager.list_adapters()) == {2} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 2 - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {4} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) - assert not manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert not manager.remove_oldest_adapter() 
+ assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) -def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = LRUCacheWorkerLoRAManager( + worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert 
worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 + assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -362,68 +363,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, ], mapping) -def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): # Should remove every LoRA not specified in the request. 
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = WorkerLoRAManager( + worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 + assert worker_adapter_manager.list_adapters() == {1, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager.list_adapters() == {1, 2, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None - assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None + assert worker_adapter_manager.list_adapters() == {1} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], 
mapping) - assert worker_lora_manager.list_loras() == {6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 + assert worker_adapter_manager.list_adapters() == {6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -461,8 +463,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up): assert isinstance(model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA) - assert manager.add_lora(model_lora) - assert manager.add_lora(model_lora1) + assert manager.add_adapter(model_lora) + assert manager.add_adapter(model_lora1) packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py index 0317c49d7380a..0cc8c8bc50fd0 100644 --- a/tests/prompt_adapter/test_multi_adapter_inference.py +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -41,8 +41,8 @@ def do_sample(engine): return results -def test_twitter_prompt_adapter(): - engine_args = EngineArgs(model="bigscience/bloomz-560m", +def test_multi_prompt_adapters(): + engine_args = EngineArgs(model=MODEL_PATH, max_prompt_adapters=3, enable_prompt_adapter=True) engine = LLMEngine.from_engine_args(engine_args) diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py new file mode 100644 index 0000000000000..d3a755b30bc79 --- /dev/null +++ b/tests/prompt_adapter/test_pa_lora.py @@ -0,0 +1,54 @@ +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.lora.request import LoRARequest +from huggingface_hub import snapshot_download + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" +pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune") +lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + +def do_sample(engine): + + prompts = [ + ('Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? 
Label :', + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech", 1, pa_path, 8), + LoRARequest("sql_test", 1, lora_path)), + + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, max_tokens=100, stop=["[/assistant]"]), + PromptAdapterRequest("hate_speech", 1, pa_path, 8), + LoRARequest("sql_test", 1, lora_path)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request, lora_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request, + lora_request=lora_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + print(results) + return results + + +def test_lora_prompt_adapter(): + engine_args = EngineArgs(model=MODEL_PATH, + enable_prompt_adapter=True, + enable_lora=True, + max_num_seqs=60) + engine = LLMEngine.from_engine_args(engine_args) + expected_output = { + " complaint", " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " + } + assert do_sample(engine) == expected_output diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index 403f5115ebc34..e949331f7ba51 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -1,4 +1,4 @@ -from abc import abstractproperty +from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Hashable, Optional, TypeVar from torch import nn @@ -9,12 +9,12 @@ logger = init_logger(__name__) -class AdapterModel: +class AdapterModel(ABC): def __init__(self, model_id=None): self.id = model_id - @classmethod + @abstractmethod def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): # Common initialization code # Load weights or embeddings from local checkpoint @@ -37,7 +37,7 @@ def _on_remove(self, key: Hashable, value: T): return super()._on_remove(key, value) -class AdapterModelManager: +class AdapterModelManager(ABC): def __init__( self, @@ -57,14 +57,17 @@ def __init__( def __len__(self) -> int: return len(self._registered_adapters) - @abstractproperty + @property + @abstractmethod def adapter_slots(self): ... - @abstractproperty + @property + @abstractmethod def capacity(self): ... 
+ @abstractmethod def _deactivate_adapter(self, adapter_id: int): raise NotImplementedError("Subclasses must implement this method.") @@ -75,16 +78,14 @@ def deactivate_adapter(self, adapter_id: int) -> bool: return True return False - def activate_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError("Subclasses must implement this method.") - + @abstractmethod def _add_adapter(self, adapter: Any): raise NotImplementedError("Subclasses must implement this method.") def add_adapter(self, adapter: Any) -> bool: if adapter.id not in self._registered_adapters: if len(self._registered_adapters) >= self.capacity: - raise RuntimeError("No free " + self.adapter_type + " slots.") + raise RuntimeError(f'No free {self.adapter_type} slots.') self._add_adapter(adapter) return True return False @@ -94,6 +95,7 @@ def set_adapter_mapping(self, mapping: Any) -> None: self._set_adapter_mapping(mapping) self._last_mapping = mapping + @abstractmethod def _set_adapter_mapping(self, mapping: Any) -> None: raise NotImplementedError("Subclasses must implement this method.") diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index 80584620343c4..516a1d57d9683 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import Any, Optional, Set import torch @@ -9,26 +9,15 @@ class AbstractWorkerManager(ABC): def __init__(self, device: torch.device): self.device = device - @abstractproperty - def _model_manager(self): - ... - - @abstractproperty - def is_enabled(self) -> bool: - ... - + @property @abstractmethod - def create_manager(self, model: torch.nn.Module) -> Any: + def is_enabled(self) -> bool: ... def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: self._apply_adapters(requests) - self._model_manager.set_adapter_mapping(mapping) - - @abstractmethod - def add_dummy_adapter(self, request: Any) -> bool: - ... 
+ self._adapter_manager.set_adapter_mapping(mapping) @abstractmethod def _load_adapter(self, request: Any) -> Any: @@ -38,8 +27,8 @@ def add_adapter(self, adapter_request: Any) -> bool: if adapter_request.adapter_id in self.list_adapters(): return False loaded_adapter = self._load_adapter(adapter_request) - loaded = self._model_manager.add_adapter(loaded_adapter) - self._model_manager.activate_adapter(loaded_adapter.id) + loaded = self._adapter_manager.add_adapter(loaded_adapter) + self._adapter_manager.activate_adapter(loaded_adapter.id) return loaded def _apply_adapters(self, adapter_requests: Set[Any]) -> None: @@ -48,11 +37,11 @@ def _apply_adapters(self, adapter_requests: Set[Any]) -> None: adapter_request.adapter_id: adapter_request for adapter_request in adapter_requests if adapter_request } - if len(models_map) > self._model_manager.adapter_slots: + if len(models_map) > self._adapter_manager.adapter_slots: raise RuntimeError( f"Number of requested models ({len(models_map)}) is greater " "than the number of GPU model slots " - f"({self._model_manager.adapter_slots}).") + f"({self._adapter_manager.adapter_slots}).") new_models = set(models_map) models_to_add = new_models - models_that_exist @@ -65,10 +54,10 @@ def _apply_adapters(self, adapter_requests: Set[Any]) -> None: self.add_adapter(models_map[adapter_id]) def remove_adapter(self, adapter_id: int) -> bool: - return self._model_manager.remove_adapter(adapter_id) + return self._adapter_manager.remove_adapter(adapter_id) def remove_all_adapters(self): - self._model_manager.remove_all_adapters() + self._adapter_manager.remove_all_adapters() def list_adapters(self) -> Set[int]: - return set(self._model_manager.list_adapters()) + return set(self._adapter_manager.list_adapters()) diff --git a/vllm/config.py b/vllm/config.py index c2e47c5280e85..b290a28477678 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1112,6 +1112,7 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): class PromptAdapterConfig: max_prompt_adapters: int max_cpu_prompt_adapters: Optional[int] = None + prompt_adapter_dtype: Optional[torch.dtype] = None def __post_init__(self): if self.max_prompt_adapters < 1: @@ -1120,6 +1121,13 @@ def __post_init__(self): if self.max_cpu_prompt_adapters is None: self.max_cpu_prompt_adapters = self.max_prompt_adapters + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype in (None, "auto"): + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + @dataclass class VisionLanguageConfig: @@ -1425,6 +1433,9 @@ def __post_init__(self): self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def to_dict(self): """Return the configs as a dictionary, for use in **kwargs. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index be9992b3deaa6..26fe602d8441a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -141,8 +141,6 @@ def __post_init__(self): self._sort_by_lora_ids() self.num_prompt_adapters: int = len(self.prompt_adapter_requests) - if self.num_prompt_adapters > 0: - self._sort_by_prompt_adapter_ids() def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. 
@@ -154,12 +152,6 @@ def _sort_by_lora_ids(self): self.scheduled_seq_groups, key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) - def _sort_by_prompt_adapter_ids(self): - self.scheduled_seq_groups = sorted( - self.scheduled_seq_groups, - key=lambda g: - (g.seq_group.prompt_adapter_id, g.seq_group.request_id)) - @property def lora_requests(self) -> Set[LoRARequest]: return { diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 281625ab287b0..39666bcd20d90 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -442,6 +442,9 @@ def _verify_args(self) -> None: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e5833c2c64a17..789450aa4ea55 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -444,7 +444,7 @@ def lora_slots(self) -> int: def adapter_slots(self) -> int: return self.lora_slots - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: @@ -472,24 +472,15 @@ def activate_lora( module.reset_lora(index) return True - @property - def activate_adapter(self): - return self.activate_lora - - def _deactivate_lora(self, lora_id: int): + def _deactivate_adapter(self, lora_id: int): try: index = self.lora_index_to_id.index(lora_id) self.lora_index_to_id[index] = None except ValueError: pass - @property - def _deactivate_adapter(self): - return self._deactivate_lora - - @property - def deactivate_lora(self): - return self.deactivate_adapter + def deactivate_adapter(self, lora_id: int) -> bool: + return super().deactivate_adapter(lora_id) def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: @@ -506,28 +497,20 @@ def _set_long_lora_context(self, lora: LoRAModel): if offsets: self.long_lora_context.offsets_by_lora_id[lora.id] = offsets - def _add_lora(self, lora: LoRAModel): + def _add_adapter(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_adapters[lora.id] = lora self._set_long_lora_context(lora) - @property - def _add_adapter(self): - return self._add_lora - - def add_lora(self, lora: LoRAModel): + def add_adapter(self, lora: LoRAModel): logger.debug( "Adding lora. 
Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - return self.add_adapter(lora) - - @property - def remove_lora(self): - return self.remove_adapter + return super().add_adapter(lora) # TODO see if this can be vectorized - def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, @@ -549,32 +532,12 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - @property - def set_lora_mapping(self): - return self.set_adapter_mapping - - @property - def _set_adapter_mapping(self): - return self._set_lora_mapping - - @property - def list_loras(self): - return self.list_adapters - - @property - def get_lora(self): - return self.get_adapter - - def remove_all_loras(self): + def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" self._registered_adapters.clear() self.lora_index_to_id = [None] * self.lora_slots self._active_adapters.clear() - @property - def remove_all_adapters(self): - return self.remove_all_loras - def _create_lora_modules(self): for module_name, module in self.model.named_modules( remove_duplicate=False): @@ -721,7 +684,7 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: class LoRALRUCache(AdapterLRUCache[LoRAModel]): def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], - None]): + bool]): super().__init__(capacity, deactivate_lora_fn) @@ -739,22 +702,22 @@ def __init__( super().__init__(model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config) self._registered_adapters: LoRALRUCache = LoRALRUCache( - self.capacity, self.deactivate_lora) + self.capacity, self.deactivate_adapter) self._active_adapters: LoRALRUCache = LoRALRUCache( - self.lora_slots, self._deactivate_lora) + self.lora_slots, self._deactivate_adapter) def list_adapters(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" return dict(self._registered_adapters.cache) - def add_lora(self, lora: LoRAModel) -> bool: + def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" logger.debug( "Adding lora. 
Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_adapters: - self._add_lora(lora) + self._add_adapter(lora) was_added = True else: # We always touch to update the LRU cache order @@ -762,19 +725,19 @@ def add_lora(self, lora: LoRAModel) -> bool: was_added = False return was_added - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: if lora_id not in self._active_adapters and len( self._active_adapters) >= self.lora_slots: self._active_adapters.remove_oldest() - result = super().activate_lora(lora_id) + result = super().activate_adapter(lora_id) # We always touch to update the LRU cache order self._active_adapters.touch(lora_id) return result - def remove_oldest_lora(self) -> bool: + def remove_oldest_adapter(self) -> bool: if len(self._registered_adapters) > 0: self._registered_adapters.remove_oldest() return True diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 699609a656ba9..f215b0ab903d1 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -44,7 +44,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings super().__init__(device) # Lazily initialized by create_lora_manager. - self._lora_manager: LoRAModelManager + self._adapter_manager: LoRAModelManager @contextmanager def dummy_lora_cache(self): @@ -70,16 +70,12 @@ def create_lora_manager( lora_config=self.lora_config, lora_manager_cls=self._manager_cls, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - @property - def set_active_loras(self): - return self.set_active_adapters - - def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._lora_manager.model + model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping expected_lora_modules: List[str] = [] @@ -115,53 +111,17 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - if lora_request.lora_int_id in self.list_loras(): + if lora_request.lora_int_id in self.list_adapters(): return False if isinstance(self._cached_dummy_lora, LoRAModel): dummy_lora = self._cached_dummy_lora.clone( lora_request.lora_int_id) else: - dummy_lora = self._lora_manager.create_dummy_lora( + dummy_lora = self._adapter_manager.create_dummy_lora( lora_request.lora_int_id, rank, 1, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora - return self._lora_manager.add_lora(dummy_lora) - - @property - def add_dummy_adapter(self): - return self.add_dummy_lora - - @property - def create_manager(self): - return self.create_lora_manager - - @property - def _load_adapter(self): - return self._load_lora - - @property - def _model_manager(self): - return self._lora_manager - - @property - def add_lora(self): - return self.add_adapter - - @property - def remove_lora(self): - return self.remove_adapter - - @property - def remove_all_loras(self): - return self.remove_all_adapters - - @property - def list_loras(self): - return self.list_adapters - - @property - def _apply_loras(self): - return self._apply_adapters + return self._adapter_manager.add_adapter(dummy_lora) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): @@ -185,7 +145,7 @@ def create_lora_manager( lora_config=self.lora_config, 
max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: @@ -193,26 +153,27 @@ def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.lora_slots: + if len(loras_map) > self._adapter_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") + f"({self._adapter_manager.lora_slots}).") for lora in loras_map.values(): - self.add_lora(lora) + self.add_adapter(lora) def add_adapter(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id not in self.list_loras(): + if lora_request.lora_int_id not in self.list_adapters(): # Remove before we load the new lora to save memory - if len(self._lora_manager) + 1 > self._lora_manager.capacity: - assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) - self._lora_manager.remove_oldest_lora() - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + lora = self._load_adapter(lora_request) + loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora( + loaded = self._adapter_manager.get_adapter( lora_request.lora_int_id) is not None - self._lora_manager.activate_lora(lora_request.lora_int_id) + self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 185a4bdfee22a..88b1f5fc6fd32 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional -import numpy import torch from torch import nn @@ -21,7 +20,8 @@ def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer self.embedding_tensors: Dict[int, torch.Tensor] = {} - self.indices: torch.Tensor + self.indices_gpu: torch.Tensor + self.flag: bool = False def reset_prompt_adapter(self, index: int): self.embedding_tensors[index] = 0 @@ -39,18 +39,18 @@ def set_mapping( self, base_indices: List[int], ): - self.indices = base_indices + self.indices_gpu = torch.tensor(base_indices, device="cuda") + self.flag = torch.sum(self.indices_gpu) > 0 def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) - unique_indices = numpy.unique(self.indices) - for idx in unique_indices: - if idx != 0: - pa_idx = self.embedding_tensors[idx].prompt_embedding - mask = (self.indices == idx) - try: + if self.flag: + unique_indices = torch.unique(self.indices_gpu) + for idx in unique_indices: + if idx != 0: + pa_idx = self.embedding_tensors[ + idx.item()].prompt_embedding + mask = (self.indices_gpu == idx) n_adapters = sum(mask) // pa_idx.shape[0] hidden_states[mask] = pa_idx.repeat(n_adapters, 1) - except Exception: - pass return hidden_states diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 9b029bc697ee2..516503c3896ae 100644 --- 
a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -2,6 +2,7 @@ import math from typing import Callable, Dict, List, Optional, Type +import torch from peft.utils import load_peft_weights from torch import nn @@ -29,18 +30,18 @@ def __init__(self, num_virtual_tokens=None, prompt_embedding=None) -> None: self.id = prompt_adapter_id - self.kv_cache = None self.prompt_embedding = prompt_embedding self.num_virtual_tokens = num_virtual_tokens @classmethod - def from_local_checkpoint(cls, - adapter_model_and_path, - prompt_adapter_id, - torch_device='cuda') -> "PromptAdapterModel": - adapters_weights = load_peft_weights(adapter_model_and_path, - torch_device) - prompt_embedding = adapters_weights["prompt_embeddings"].half() + def from_local_checkpoint( + cls, + adapter_model_path: str, + prompt_adapter_id: int, + device: str = "cuda", + dtype: Optional[torch.dtype] = None) -> "PromptAdapterModel": + adapters_weights = load_peft_weights(adapter_model_path, device) + prompt_embedding = adapters_weights["prompt_embeddings"].to(dtype) num_virtual_tokens = prompt_embedding.shape[0] return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) @@ -62,8 +63,8 @@ def __init__( """ self.model: nn.Module = model # Dict instead of a Set for compatibility with LRUCache. - self.prompt_adapter_index_to_id: List[Optional[int]] =\ - [None] * self.prompt_adapter_slots + self.prompt_adapter_index_to_id: List[ + Optional[int]] = [None] * self.prompt_adapter_slots self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.prompt_adapter_config = prompt_adapter_config @@ -87,7 +88,7 @@ def adapter_slots(self) -> int: def capacity(self) -> int: return self.prompt_adapter_config.max_cpu_prompt_adapters - def activate_prompt_adapter( + def activate_adapter( self, prompt_adapter_id: int, ) -> bool: @@ -96,9 +97,9 @@ def activate_prompt_adapter( if prompt_adapter_id in self._active_adapters: return False first_free_slot = next( - ((i, prompt_adapter_id) for i, prompt_adapter_id in \ - enumerate(self.prompt_adapter_index_to_id) - if prompt_adapter_id is None), None) + ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( + self.prompt_adapter_index_to_id) if prompt_adapter_id is None), + None) if first_free_slot is None: raise ValueError("No free prompt_adapter slots") index, _ = first_free_slot @@ -112,11 +113,7 @@ def activate_prompt_adapter( v.set_prompt_adapter(prompt_adapter_id, prompt_adapter_model) return True - @property - def activate_adapter(self): - return self.activate_prompt_adapter - - def _deactivate_prompt_adapter(self, prompt_adapter_id: int): + def _deactivate_adapter(self, prompt_adapter_id: int): try: index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) self.prompt_adapter_index_to_id[index] = None @@ -125,31 +122,10 @@ def _deactivate_prompt_adapter(self, prompt_adapter_id: int): except ValueError: pass - @property - def _deactivate_adapter(self): - return self._deactivate_prompt_adapter - - @property - def deactivate_prompt_adapter(self): - return self.deactivate_adapter - - def _add_prompt_adapter(self, prompt_adapter: PromptAdapterModel): + def _add_adapter(self, prompt_adapter: PromptAdapterModel): self._registered_adapters[prompt_adapter.id] = prompt_adapter - @property - def _add_adapter(self): - return self._add_prompt_adapter - - @property - def add_prompt_adapter(self): - return self.add_adapter - - @property - def remove_prompt_adapter(self): - return self.remove_adapter - - def 
_set_prompt_adapter_mapping(self, - mapping: PromptAdapterMapping) -> None: + def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: for k, v in self.modules.items(): v.set_mapping(mapping.index_mapping) @@ -163,6 +139,7 @@ def _create_prompt_adapter_modules(self): self.register_module(module.__class__.__name__, replaced_module) replaced_module.set_mapping(self.base_indices) + break def replace_submodule(self, model: nn.Module, module_name: str, new_module: nn.Module) -> nn.Module: @@ -175,32 +152,12 @@ def replace_submodule(self, model: nn.Module, module_name: str, def register_module(self, module_name: str, module: nn.Module): self.modules[module_name] = module - @property - def set_prompt_adapter_mapping(self): - return self.set_adapter_mapping - - @property - def _set_adapter_mapping(self): - return self._set_prompt_adapter_mapping - - @property - def list_prompt_adapters(self): - return self.list_adapters - - @property - def get_prompt_adapter(self): - return self.get_adapter - - def remove_all_prompt_adapters(self): + def remove_all_adapters(self): """Remove all PromptAdapterModel from the manager.""" self._registered_adapters.clear() self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots self._active_adapters.clear() - @property - def remove_all_adapters(self): - return self.remove_all_prompt_adapters - class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): @@ -220,14 +177,12 @@ def __init__( prompt_adapter_config: PromptAdapterConfig, ): self.prompt_adapter_config = prompt_adapter_config - super().__init__(model, max_num_seqs, \ - max_num_batched_tokens, prompt_adapter_config) - self._registered_adapters: PromptAdapterLRUCache = \ - PromptAdapterLRUCache(self.capacity, - self.deactivate_prompt_adapter) - self._active_adapters: PromptAdapterLRUCache = \ - PromptAdapterLRUCache(self.prompt_adapter_slots, - self._deactivate_prompt_adapter) + super().__init__(model, max_num_seqs, max_num_batched_tokens, + prompt_adapter_config) + self._registered_adapters = PromptAdapterLRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters = PromptAdapterLRUCache( + self.prompt_adapter_slots, self._deactivate_adapter) def list_adapters(self) -> Dict[int, PromptAdapterModel]: """List all registered PromptAdapterModel.""" @@ -236,7 +191,7 @@ def list_adapters(self) -> Dict[int, PromptAdapterModel]: def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: """Add a PromptAdapterModel to the manager.""" if prompt_adapter.id not in self._registered_adapters: - self._add_prompt_adapter(prompt_adapter) + self._add_adapter(prompt_adapter) was_added = True else: # We always touch to update the LRU cache order @@ -251,12 +206,12 @@ def activate_adapter( if prompt_adapter_id not in self._active_adapters and len( self._active_adapters) >= self.prompt_adapter_slots: self._active_adapters.remove_oldest() - result = super().activate_prompt_adapter(prompt_adapter_id) + result = super().activate_adapter(prompt_adapter_id) # We always touch to update the LRU cache order self._active_adapters.touch(prompt_adapter_id) return result - def remove_oldest_prompt_adapter(self) -> bool: + def remove_oldest_adapter(self) -> bool: if len(self._registered_adapters) > 0: self._registered_adapters.remove_oldest() return True @@ -268,12 +223,14 @@ def create_prompt_adapter_manager( max_num_seqs: int, max_num_batched_tokens: int, prompt_adapter_config: PromptAdapterConfig, - prompt_adapter_manager_cls: Type[PromptAdapterModelManager] \ - = 
PromptAdapterModelManager, + prompt_adapter_manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager, **kwargs) -> PromptAdapterModelManager: """Create a PromptAdapterModel for a given model.""" prompt_adapter_manager = prompt_adapter_manager_cls( - model=model, max_num_seqs=max_num_seqs, \ - max_num_batched_tokens=max_num_batched_tokens, \ - prompt_adapter_config=prompt_adapter_config, **kwargs) + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + prompt_adapter_config=prompt_adapter_config, + **kwargs) return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py index 31ba38c420583..c0c98cf72bbae 100644 --- a/vllm/prompt_adapter/request.py +++ b/vllm/prompt_adapter/request.py @@ -13,7 +13,9 @@ class PromptAdapterRequest(AdapterRequest): prompt_adapter_id: int prompt_adapter_local_path: str prompt_adapter_num_virtual_tokens: int - __hash__ = AdapterRequest.__hash__ + + def __hash__(self): + return super().__hash__() @property def adapter_id(self): diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index 77e5aefab20cf..7bfdeec15b3e8 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional, Set, Type +from typing import Any, Set, Type import torch @@ -32,8 +32,7 @@ def __init__( prompt_adapter_config: PromptAdapterConfig, prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel ): - self._prompt_adapter_manager: Optional[ - PromptAdapterModelManager] = None + self._adapter_manager: PromptAdapterModelManager self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = max_num_batched_tokens self._prompt_adapter_model_cls = prompt_adapter_model_cls @@ -55,23 +54,19 @@ def create_prompt_adapter_manager( prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_manager_cls=self._manager_cls, ) - self._prompt_adapter_manager = prompt_adapter_manager + self._adapter_manager = prompt_adapter_manager return prompt_adapter_manager.model - @property - def set_active_prompt_adapters(self): - return self.set_active_adapters - - def _load_prompt_adapter( + def _load_adapter( self, prompt_adapter_request: PromptAdapterRequest ) -> PromptAdapterModel: try: - prompt_adapter = self._prompt_adapter_model_cls\ - .from_local_checkpoint( + prompt_adapter = ( + self._prompt_adapter_model_cls.from_local_checkpoint( prompt_adapter_request.prompt_adapter_local_path, prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, - torch_device=str(self.device) - ) + device=str(self.device), + dtype=self.prompt_adapter_config.prompt_adapter_dtype)) except Exception as e: raise RuntimeError( f"Loading prompt_adapter " @@ -83,42 +78,6 @@ def add_dummy_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: return True - @property - def add_dummy_adapter(self): - return self.add_dummy_prompt_adapter - - @property - def create_manager(self): - return self.create_prompt_adapter_manager - - @property - def _load_adapter(self): - return self._load_prompt_adapter - - @property - def _model_manager(self): - return self._prompt_adapter_manager - - @property - def add_prompt_adapter(self): - return self.add_adapter - - @property - def remove_prompt_adapter(self): - return self.remove_adapter - - @property - def remove_all_prompt_adapters(self): - return self.remove_all_adapters - - @property - def list_prompt_adapters(self): - return 
self.list_adapters - - @property - def _apply_prompt_adapters(self): - return self._apply_adapters - class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): """WorkerPromptAdapterManager that manages @@ -142,8 +101,8 @@ def create_prompt_adapter_manager( max_num_batched_tokens=self.max_num_batched_tokens, prompt_adapter_config=self.prompt_adapter_config, prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) - self._prompt_adapter_manager: \ - LRUCachePromptAdapterModelManager = prompt_adapter_manager + self._adapter_manager: LRUCachePromptAdapterModelManager = ( + prompt_adapter_manager) return prompt_adapter_manager.model def _apply_adapters( @@ -154,31 +113,29 @@ def _apply_adapters( if prompt_adapter_request } if len(prompt_adapters_map - ) > self._prompt_adapter_manager.prompt_adapter_slots: + ) > self._adapter_manager.prompt_adapter_slots: raise RuntimeError( f"Number of requested prompt_adapters " f"({len(prompt_adapters_map)}) is greater " "than the number of GPU prompt_adapter slots " - f"({self._prompt_adapter_manager.prompt_adapter_slots}).") + f"({self._adapter_manager.prompt_adapter_slots}).") for prompt_adapter in prompt_adapters_map.values(): - self.add_prompt_adapter(prompt_adapter) + self.add_adapter(prompt_adapter) def add_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: - if prompt_adapter_request.prompt_adapter_id not in \ - self.list_prompt_adapters(): + if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( + ): # Remove before we load the new prompt_adapter to save memory - if len(self._prompt_adapter_manager - ) + 1 > self._prompt_adapter_manager.capacity: - self._prompt_adapter_manager.remove_oldest_prompt_adapter() - prompt_adapter = self._load_prompt_adapter(prompt_adapter_request) - loaded = self._prompt_adapter_manager.add_prompt_adapter( - prompt_adapter) + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + self._adapter_manager.remove_oldest_adapter() + prompt_adapter = self._load_adapter(prompt_adapter_request) + loaded = self._adapter_manager.add_adapter(prompt_adapter) else: # If the prompt_adapter is already loaded, just touch it to # update its position in the caches - loaded = self._prompt_adapter_manager.get_prompt_adapter( + loaded = self._adapter_manager.get_adapter( prompt_adapter_request.prompt_adapter_id) - self._prompt_adapter_manager.activate_prompt_adapter( + self._adapter_manager.activate_adapter( prompt_adapter_request.prompt_adapter_id) return loaded diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d627cfd428df0..adc970c43afad 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -196,8 +196,9 @@ def load_model(self) -> None: self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, self.device, self.prompt_adapter_config) - self.model = self.prompt_adapter_manager\ - .create_prompt_adapter_manager(self.model) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors @@ -902,28 +903,28 @@ def profile_run(self) -> None: def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_adapters() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - 
self.lora_manager.set_active_loras(lora_requests, lora_mapping) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_lora(lora_request) + return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_lora(lora_id) + return self.lora_manager.remove_adapter(lora_id) def list_loras(self) -> Set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_loras() + return self.lora_manager.list_adapters() def remove_all_prompt_adapters(self): if not self.prompt_adapter_manager: @@ -935,25 +936,24 @@ def set_active_prompt_adapters( prompt_adapter_mapping: PromptAdapterMapping) -> None: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.set_active_prompt_adapters( + self.prompt_adapter_manager.set_active_adapters( prompt_adapter_requests, prompt_adapter_mapping) def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.add_prompt_adapter( - prompt_adapter_request) + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.remove_lora(prompt_adapter_id) + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) def list_prompt_adapters(self) -> Set[int]: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.list_prompt_adapters() + return self.prompt_adapter_manager.list_adapters() @torch.inference_mode() def capture_model(self, kv_caches: List[torch.Tensor]) -> None: From 3d14475f810ead267372250830b9d910c8dd4114 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:26:30 -0400 Subject: [PATCH 39/80] formatting --- vllm/prompt_adapter/models.py | 4 ++-- vllm/worker/xpu_model_runner.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 516503c3896ae..03090a446a7d9 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -104,8 +104,8 @@ def activate_adapter( raise ValueError("No free prompt_adapter slots") index, _ = first_free_slot self._active_adapters[prompt_adapter_id] = None - prompt_adapter_model = \ - self._registered_adapters[prompt_adapter_id] + prompt_adapter_model = ( + self._registered_adapters[prompt_adapter_id]) logger.debug("Activating prompt_adapter. 
int id: %d, slot index: %d", prompt_adapter_model.id, index) self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f30de703e805d..bb2dd321755ac 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -6,7 +6,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + VisionLanguageConfig, PromptAdapterConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -37,6 +37,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -48,6 +49,7 @@ def __init__( self.load_config = load_config self.cache_config = cache_config self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker self.sliding_window = model_config.get_sliding_window() From 86e72decc9c4cbadbde946d0864a0b3a679e5101 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:26:45 -0400 Subject: [PATCH 40/80] formatting --- vllm/prompt_adapter/models.py | 3 +-- vllm/worker/xpu_model_runner.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 03090a446a7d9..e36ab687829c5 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -104,8 +104,7 @@ def activate_adapter( raise ValueError("No free prompt_adapter slots") index, _ = first_free_slot self._active_adapters[prompt_adapter_id] = None - prompt_adapter_model = ( - self._registered_adapters[prompt_adapter_id]) + prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) logger.debug("Activating prompt_adapter. 
int id: %d, slot index: %d", prompt_adapter_model.id, index) self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index bb2dd321755ac..bc27c13dd3119 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -5,8 +5,8 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model From 41934ccb8d0eae2e384ef13f367ecd97fcf240aa Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:34:27 -0400 Subject: [PATCH 41/80] xpu compatibility --- vllm/executor/xpu_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index d37200bd02de3..4af3a3de977a4 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -4,7 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, + PromptAdapterConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -27,6 +28,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -43,6 +45,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.prompt_adapter_config= prompt_adapter_config, self.speculative_config = None # Instantiate the worker and load the model to GPU. 
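
Note on the dtype flexibility introduced in PATCH 38: PromptAdapterConfig.verify_with_model_config() resolves prompt_adapter_dtype against the base model's dtype (None or "auto" inherits it, a string such as "float16" is looked up via getattr(torch, ...)), and PromptAdapterModel.from_local_checkpoint() casts the loaded PEFT prompt embeddings to that dtype. A rough usage sketch only; the model_config object and the checkpoint path below are placeholders, not part of this series:

    import torch

    from vllm.config import PromptAdapterConfig
    from vllm.prompt_adapter.models import PromptAdapterModel

    # prompt_adapter_dtype=None (or "auto") falls back to model_config.dtype;
    # a string like "float16" is resolved to torch.float16.
    pa_config = PromptAdapterConfig(max_prompt_adapters=4,
                                    prompt_adapter_dtype="float16")
    pa_config.verify_with_model_config(model_config)  # model_config: assumed ModelConfig

    # Load the PEFT prompt-tuning weights in the resolved dtype.
    adapter = PromptAdapterModel.from_local_checkpoint(
        "/path/to/prompt_tuning_checkpoint",   # placeholder path
        prompt_adapter_id=1,
        device="cuda",
        dtype=pa_config.prompt_adapter_dtype)
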
From fdfec592d9dd5889a10a1bb8237a6551b4ac02b1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:35:42 -0400 Subject: [PATCH 42/80] xpu compatibility --- vllm/executor/ray_xpu_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index dd7c82289341e..4174fc6bf0923 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, + PromptAdapterConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -44,6 +45,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -58,6 +60,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config placement_group = self.parallel_config.placement_group From 6b1f0e7df82575d30f74cf541297f4b52452a22b Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:35:59 -0400 Subject: [PATCH 43/80] xpu compatibility --- vllm/executor/ray_xpu_executor.py | 6 +++--- vllm/executor/xpu_executor.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 4174fc6bf0923..ebd38300cd1e9 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -7,9 +7,9 @@ Tuple, Union) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig, - PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 4af3a3de977a4..9fc6033510bc6 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -3,9 +3,9 @@ import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig, - PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -45,7 +45,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config - self.prompt_adapter_config= prompt_adapter_config, + self.prompt_adapter_config = prompt_adapter_config, self.speculative_config = None # Instantiate the worker and load the model to GPU. 
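[editor's illustration] The XPU executors above now receive the same PromptAdapterConfig as the GPU path, and prompt-adapter management on an executor is thin delegation to its driver worker. A rough sketch of that pattern, with stand-in classes and simplified signatures rather than the real executor/worker types:

# Sketch of the executor -> driver-worker delegation pattern for prompt
# adapters (stand-in classes; signatures are simplified assumptions).
from typing import Set


class DriverWorkerStub:
    def __init__(self) -> None:
        self._prompt_adapters: Set[int] = set()

    def add_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        self._prompt_adapters.add(prompt_adapter_id)
        return True

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        if prompt_adapter_id in self._prompt_adapters:
            self._prompt_adapters.remove(prompt_adapter_id)
            return True
        return False

    def list_prompt_adapters(self) -> Set[int]:
        return set(self._prompt_adapters)


class ExecutorSketch:
    def __init__(self, driver_worker: DriverWorkerStub) -> None:
        self.driver_worker = driver_worker

    def add_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        # validate on the executor, then forward to the driver worker
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.add_prompt_adapter(prompt_adapter_id)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.driver_worker.list_prompt_adapters()
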
From 01bb713b9872e6689540d1037e076bcf966ed4cf Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 17:46:32 -0400 Subject: [PATCH 44/80] xpu compatibility --- vllm/worker/xpu_worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 773ee9f8159e1..ee03368295781 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -9,8 +9,9 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -47,6 +48,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: assert device_config.device_type == "xpu" @@ -62,6 +64,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." From d7312e29e9175c9649246b159970e9e52a23de89 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 19:57:53 -0400 Subject: [PATCH 45/80] formatting --- vllm/lora/models.py | 8 ++++---- vllm/worker/worker.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index bb0424f5b6417..ef4861531e7f2 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -757,17 +757,17 @@ def pin_lora(self, lora_id: int) -> bool: def _pin_lora_in_cpu_cache(self, lora_id: int): try: - self._registered_loras.pin(lora_id) + self._registered_adapters.pin(lora_id) except ValueError as err: raise ValueError("Pinning failed. " f"LoRA {lora_id} is not registered.") from err def _pin_lora_in_gpu_cache(self, lora_id: int): - if lora_id not in self._active_loras: + if lora_id not in self._active_adapters: # move lora to gpu if not already active - self.activate_lora(lora_id) + self.activate_adapter(lora_id) - self._active_loras.pin(lora_id) + self._active_adapters.pin(lora_id) def create_lora_manager( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6c1a0307afb04..eeeefd02f64b0 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -95,7 +95,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, - prompt_adapter_config=prompt_adapter_config) + prompt_adapter_config=prompt_adapter_config, **speculative_args, ) # Uninitialized cache engine. 
Will be initialized by From 454d45b897309dae7ed2e1f70a33581db97cd097 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 21:13:19 -0400 Subject: [PATCH 46/80] formatting + updating tests --- tests/lora/test_lora_manager.py | 44 ++++++++++++++-------------- tests/lora/test_worker.py | 8 ++--- tests/prompt_adapter/test_pa_lora.py | 30 +++++++++++-------- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 6ec80193ce923..e841a9d950cd0 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -212,13 +212,13 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): assert manager.pin_lora(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.activate_lora(1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.deactivate_lora(2) + assert manager.deactivate_adapter(2) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 assert manager.pin_lora(3) @@ -228,13 +228,13 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 with pytest.raises(RuntimeError): - assert manager.activate_lora(2) + assert manager.activate_adapter(2) - assert manager.deactivate_lora(3) + assert manager.deactivate_adapter(3) assert manager.pin_lora(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.remove_lora(3) + assert manager.remove_adapter(3) with pytest.raises(ValueError): assert manager.pin_lora(3) @@ -317,40 +317,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert all(x is None for x in manager.lora_index_to_id) # pinning - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) - assert set(manager.list_loras()) == {3, 4} + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) + assert set(manager.list_adapters()) == {3, 4} with pytest.raises(ValueError): assert manager.pin_lora(1) assert manager.pin_lora(3) # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {4} + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.add_lora(model_lora1) + assert manager.add_adapter(model_lora1) assert manager.pin_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {1} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {1} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] is None with pytest.raises(RuntimeError): - 
assert manager.remove_oldest_lora() + assert manager.remove_oldest_adapter() - assert set(manager.list_loras()) == {1} + assert set(manager.list_adapters()) == {1} def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 732e91a52c0a9..458447b9d2bb2 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -42,7 +42,7 @@ def test_worker_apply_lora(sql_lora_files): worker.init_device() worker.load_model() - worker.model_runner.set_active_loras([], LoRAMapping([], [])) + worker.model_runner.set_active_adapters([], LoRAMapping([], [])) assert worker.list_loras() == set() n_loras = 32 @@ -50,7 +50,7 @@ def test_worker_apply_lora(sql_lora_files): LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) ] - worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], [])) + worker.model_runner.set_active_adapters(lora_requests, LoRAMapping([], [])) assert worker.list_loras() == { lora_request.lora_int_id for lora_request in lora_requests @@ -62,8 +62,8 @@ def test_worker_apply_lora(sql_lora_files): k=random.randint(1, n_loras)) random.shuffle(iter_lora_requests) iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] - worker.model_runner.set_active_loras(iter_lora_requests, - LoRAMapping([], [])) + worker.model_runner.set_active_adapters(iter_lora_requests, + LoRAMapping([], [])) assert worker.list_loras().issuperset( {lora_request.lora_int_id for lora_request in iter_lora_requests}) diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py index d3a755b30bc79..a440f26512ef6 100644 --- a/tests/prompt_adapter/test_pa_lora.py +++ b/tests/prompt_adapter/test_pa_lora.py @@ -1,24 +1,29 @@ +from huggingface_hub import snapshot_download + from vllm import EngineArgs, LLMEngine, SamplingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.lora.request import LoRARequest -from huggingface_hub import snapshot_download +from vllm.prompt_adapter.request import PromptAdapterRequest MODEL_PATH = "meta-llama/Llama-2-7b-hf" pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune") lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + def do_sample(engine): prompts = [ - ('Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? Label :', - SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), - PromptAdapterRequest("hate_speech", 1, pa_path, 8), - LoRARequest("sql_test", 1, lora_path)), - - ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", - SamplingParams(temperature=0.0, max_tokens=100, stop=["[/assistant]"]), - PromptAdapterRequest("hate_speech", 1, pa_path, 8), - LoRARequest("sql_test", 1, lora_path)), + ( + 'Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? 
Label :', # noqa: E501 + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech", 1, pa_path, 8), + LoRARequest("sql_test", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + max_tokens=100, + stop=["[/assistant]"]), + PromptAdapterRequest("hate_speech", 1, pa_path, 8), + LoRARequest("sql_test", 1, lora_path)), ] request_id = 0 @@ -49,6 +54,7 @@ def test_lora_prompt_adapter(): max_num_seqs=60) engine = LLMEngine.from_engine_args(engine_args) expected_output = { - " complaint", " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " + " complaint", + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501 } assert do_sample(engine) == expected_output From 409dba11d011b5dbd9bdbe4a9e3d8d6566d8c26e Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 23 Jun 2024 21:14:52 -0400 Subject: [PATCH 47/80] test changes --- tests/lora/test_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 458447b9d2bb2..f6341754704b6 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -42,7 +42,7 @@ def test_worker_apply_lora(sql_lora_files): worker.init_device() worker.load_model() - worker.model_runner.set_active_adapters([], LoRAMapping([], [])) + worker.model_runner.set_active_loras([], LoRAMapping([], [])) assert worker.list_loras() == set() n_loras = 32 @@ -50,7 +50,7 @@ def test_worker_apply_lora(sql_lora_files): LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) ] - worker.model_runner.set_active_adapters(lora_requests, LoRAMapping([], [])) + worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], [])) assert worker.list_loras() == { lora_request.lora_int_id for lora_request in lora_requests @@ -62,7 +62,7 @@ def test_worker_apply_lora(sql_lora_files): k=random.randint(1, n_loras)) random.shuffle(iter_lora_requests) iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] - worker.model_runner.set_active_adapters(iter_lora_requests, + worker.model_runner.set_active_loras(iter_lora_requests, LoRAMapping([], [])) assert worker.list_loras().issuperset( {lora_request.lora_int_id From ab95ad70903cd3d80393dc30d3f96780ad0f1218 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 26 Jun 2024 00:02:37 -0400 Subject: [PATCH 48/80] cpu-gpu sync changes + adapter abstract changes --- requirements-common.txt | 1 - tests/lora/test_lora_manager.py | 18 +++--- tests/lora/test_worker.py | 2 +- tests/prompt_adapter/test_pa_lora.py | 28 +++++---- vllm/adapter_commons/models.py | 52 ++++++++--------- vllm/adapter_commons/request.py | 8 ++- vllm/adapter_commons/utils.py | 69 ++++++++++++++++++++++ vllm/adapter_commons/worker_manager.py | 47 ++++----------- vllm/config.py | 9 ++- vllm/engine/arg_utils.py | 8 ++- vllm/executor/cpu_executor.py | 3 + vllm/executor/executor_base.py | 4 ++ vllm/executor/gpu_executor.py | 4 ++ vllm/lora/models.py | 39 +++++++++---- vllm/lora/worker_manager.py | 25 +++++++- vllm/prompt_adapter/layers.py | 45 +++++++++----- vllm/prompt_adapter/models.py | 81 ++++++++++++++++++++++++-- vllm/prompt_adapter/worker_manager.py | 26 ++++++++- vllm/worker/worker.py | 5 +- 19 files 
changed, 348 insertions(+), 126 deletions(-) create mode 100644 vllm/adapter_commons/utils.py diff --git a/requirements-common.txt b/requirements-common.txt index 3ac561770298c..05969cfa5d65f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -21,4 +21,3 @@ lm-format-enforcer == 0.10.1 outlines >= 0.0.43 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 -peft diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index e841a9d950cd0..7bff9e1fbcdcc 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -209,7 +209,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 assert manager.activate_adapter(1) @@ -221,22 +221,22 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.pin_lora(3) - assert manager.pin_lora(1) + assert manager.pin_adapter(3) + assert manager.pin_adapter(1) with pytest.raises(RuntimeError): - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 with pytest.raises(RuntimeError): assert manager.activate_adapter(2) assert manager.deactivate_adapter(3) - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 assert manager.remove_adapter(3) with pytest.raises(ValueError): - assert manager.pin_lora(3) + assert manager.pin_adapter(3) def test_lru_lora_model_manager(dist_init, dummy_model): @@ -323,8 +323,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_adapter(4) assert set(manager.list_adapters()) == {3, 4} with pytest.raises(ValueError): - assert manager.pin_lora(1) - assert manager.pin_lora(3) + assert manager.pin_adapter(1) + assert manager.pin_adapter(3) # Remove manually assert manager.remove_adapter(3) assert not manager.remove_adapter(3) @@ -334,7 +334,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.lora_index_to_id[1] == 4 assert manager.add_adapter(model_lora1) - assert manager.pin_lora(1) + assert manager.pin_adapter(1) assert manager.add_adapter(model_lora2) assert manager.activate_adapter(2) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index f6341754704b6..ef6a4912b68fd 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -63,7 +63,7 @@ def test_worker_apply_lora(sql_lora_files): random.shuffle(iter_lora_requests) iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] worker.model_runner.set_active_loras(iter_lora_requests, - LoRAMapping([], [])) + LoRAMapping([], [])) assert worker.list_loras().issuperset( {lora_request.lora_int_id for lora_request in iter_lora_requests}) diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py index a440f26512ef6..4eab29ddced42 100644 --- a/tests/prompt_adapter/test_pa_lora.py +++ b/tests/prompt_adapter/test_pa_lora.py @@ -10,20 +10,27 @@ def do_sample(engine): - + + prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n 
context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501 + + # first prompt with a prompt adapter and second without adapter prompts = [ ( - 'Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? Label :', # noqa: E501 - SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + prompt_text, + SamplingParams(temperature=0.0, + max_tokens=100, + stop=["[/assistant]"]), PromptAdapterRequest("hate_speech", 1, pa_path, 8), - LoRARequest("sql_test", 1, lora_path)), + LoRARequest("sql_test", 1, lora_path) + ), ( - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + prompt_text, SamplingParams(temperature=0.0, max_tokens=100, stop=["[/assistant]"]), - PromptAdapterRequest("hate_speech", 1, pa_path, 8), - LoRARequest("sql_test", 1, lora_path)), + None, + LoRARequest("sql_test", 1, lora_path) + ), ] request_id = 0 @@ -43,7 +50,6 @@ def do_sample(engine): for request_output in request_outputs: if request_output.finished: results.add(request_output.outputs[0].text) - print(results) return results @@ -53,8 +59,10 @@ def test_lora_prompt_adapter(): enable_lora=True, max_num_seqs=60) engine = LLMEngine.from_engine_args(engine_args) + result = do_sample(engine) + expected_output = { - " complaint", + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501, " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501 } - assert do_sample(engine) == expected_output + assert result == expected_output diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index e949331f7ba51..0f9c6867e88a4 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -66,45 +66,39 @@ def adapter_slots(self): @abstractmethod def capacity(self): ... - + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + ... + @abstractmethod - def _deactivate_adapter(self, adapter_id: int): - raise NotImplementedError("Subclasses must implement this method.") - def deactivate_adapter(self, adapter_id: int) -> bool: - if adapter_id in self._active_adapters: - self._deactivate_adapter(adapter_id) - self._active_adapters.pop(adapter_id) - return True - return False + ... @abstractmethod - def _add_adapter(self, adapter: Any): - raise NotImplementedError("Subclasses must implement this method.") - def add_adapter(self, adapter: Any) -> bool: - if adapter.id not in self._registered_adapters: - if len(self._registered_adapters) >= self.capacity: - raise RuntimeError(f'No free {self.adapter_type} slots.') - self._add_adapter(adapter) - return True - return False + ... + @abstractmethod def set_adapter_mapping(self, mapping: Any) -> None: - if self._last_mapping != mapping: - self._set_adapter_mapping(mapping) - self._last_mapping = mapping + ... @abstractmethod - def _set_adapter_mapping(self, mapping: Any) -> None: - raise NotImplementedError("Subclasses must implement this method.") - def remove_adapter(self, adapter_id: int) -> bool: - self.deactivate_adapter(adapter_id) - return bool(self._registered_adapters.pop(adapter_id, None)) + ... 
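[editor's illustration] With this commit the adapter_commons base class only declares the interface, and the concrete behaviour moves into the shared helpers added in vllm/adapter_commons/utils.py below. As a rough, standalone illustration of what a concrete manager is expected to provide (plain dicts and illustrative names, not the vLLM classes):

# Standalone sketch of a concrete manager satisfying the abstract interface:
# a dict of registered adapters, a dict of active ones, and removal that
# deactivates first, mirroring the shared helpers in adapter_commons/utils.py.
from typing import Any, Dict, Optional


class TinyAdapterManager:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self._registered_adapters: Dict[int, Any] = {}
        self._active_adapters: Dict[int, None] = {}

    def add_adapter(self, adapter: Any) -> bool:
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._registered_adapters[adapter.id] = adapter
        return True

    def deactivate_adapter(self, adapter_id: int) -> bool:
        if adapter_id in self._active_adapters:
            self._active_adapters.pop(adapter_id)
            return True
        return False

    def remove_adapter(self, adapter_id: int) -> bool:
        self.deactivate_adapter(adapter_id)
        return self._registered_adapters.pop(adapter_id, None) is not None

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return self._registered_adapters.get(adapter_id)

    def list_adapters(self) -> Dict[int, Any]:
        return dict(self._registered_adapters)


class _DummyAdapter:
    def __init__(self, adapter_id: int) -> None:
        self.id = adapter_id


mgr = TinyAdapterManager(capacity=2)
assert mgr.add_adapter(_DummyAdapter(1))
assert 1 in mgr.list_adapters()
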
- def list_adapters(self) -> Dict[int, Any]: - return dict(self._registered_adapters) + @abstractmethod + def remove_all_adapters(self): + ... + @abstractmethod def get_adapter(self, adapter_id: int) -> Optional[Any]: - return self._registered_adapters.get(adapter_id, None) + ... + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + ... + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + ... \ No newline at end of file diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py index cf59e29b3c7cd..31b32d766dba9 100644 --- a/vllm/adapter_commons/request.py +++ b/vllm/adapter_commons/request.py @@ -1,12 +1,16 @@ from dataclasses import dataclass - +from abc import abstractmethod @dataclass class AdapterRequest: """ Base class for adapter requests. """ - + @property + @abstractmethod + def adapter_id(self): + ... + def __post_init__(self): if self.adapter_id < 1: raise ValueError(f"id must be > 0, got {self.adapter_id}") diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 0000000000000..6998fe47b2f5e --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,69 @@ +from typing import Dict, Optional, Any, Callable, Set + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError(f'No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + +def set_adapter_mapping(mapping: Any, last_mapping: Any, set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + +def get_adapter(adapter_id: int, registered_adapters: Dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id, None) + +## worker functions +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], apply_adapters_func, set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + +def add_adapter_worker(adapter_request: Any, list_adapters_func, load_adapter_func, add_adapter_func, activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, adapter_slots: int, remove_adapter_func, add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + "than the number of GPU model slots 
" + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) \ No newline at end of file diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index 516a1d57d9683..ae975343b1ec8 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -14,50 +14,23 @@ def __init__(self, device: torch.device): def is_enabled(self) -> bool: ... + @abstractmethod def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: - self._apply_adapters(requests) - self._adapter_manager.set_adapter_mapping(mapping) - - @abstractmethod - def _load_adapter(self, request: Any) -> Any: ... - + + @abstractmethod def add_adapter(self, adapter_request: Any) -> bool: - if adapter_request.adapter_id in self.list_adapters(): - return False - loaded_adapter = self._load_adapter(adapter_request) - loaded = self._adapter_manager.add_adapter(loaded_adapter) - self._adapter_manager.activate_adapter(loaded_adapter.id) - return loaded - - def _apply_adapters(self, adapter_requests: Set[Any]) -> None: - models_that_exist = self.list_adapters() - models_map = { - adapter_request.adapter_id: adapter_request - for adapter_request in adapter_requests if adapter_request - } - if len(models_map) > self._adapter_manager.adapter_slots: - raise RuntimeError( - f"Number of requested models ({len(models_map)}) is greater " - "than the number of GPU model slots " - f"({self._adapter_manager.adapter_slots}).") - - new_models = set(models_map) - models_to_add = new_models - models_that_exist - models_to_remove = models_that_exist - new_models - - for adapter_id in models_to_remove: - self.remove_adapter(adapter_id) - - for adapter_id in models_to_add: - self.add_adapter(models_map[adapter_id]) + ... + @abstractmethod def remove_adapter(self, adapter_id: int) -> bool: - return self._adapter_manager.remove_adapter(adapter_id) + ... + @abstractmethod def remove_all_adapters(self): - self._adapter_manager.remove_all_adapters() + ... + @abstractmethod def list_adapters(self) -> Set[int]: - return set(self._adapter_manager.list_adapters()) + ... \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 634035feb2779..78d9088c2e9dd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1135,10 +1135,17 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): @dataclass class PromptAdapterConfig: max_prompt_adapters: int + max_prompt_adapter_token: int = 10 max_cpu_prompt_adapters: Optional[int] = None prompt_adapter_dtype: Optional[torch.dtype] = None - + def __post_init__(self): + library_name = 'peft' + try: + __import__(library_name) + except ImportError: + raise ImportError(f"'{library_name}' is not installed for prompt adapter support. 
Please install it using 'pip install {library_name}'.") + if self.max_prompt_adapters < 1: raise ValueError(f"max_prompt_adapters " f"({self.max_prompt_adapters}) must be >= 1.") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index edb1ecd428d6f..3b3eb0573d4c0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -69,6 +69,7 @@ class EngineArgs: max_lora_rank: int = 16 enable_prompt_adapter: bool = False max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 10 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None @@ -513,6 +514,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.max_prompt_adapters, help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') parser.add_argument( "--device", type=str, @@ -743,7 +748,8 @@ def create_engine_config(self, ) -> EngineConfig: ) prompt_adapter_config = PromptAdapterConfig( - max_prompt_adapters=self.max_prompt_adapters) \ + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ if self.enable_prompt_adapter else None if self.image_input_type: diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 0c6b57299a46a..953151e8adde6 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -102,6 +102,9 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> Set[int]: return self.driver_worker.list_prompt_adapters() + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 7af8eb4b49b50..45d46bbd6c711 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -102,6 +102,10 @@ def add_prompt_adapter( def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: raise NotImplementedError + @abstractmethod + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError # type: ignore + @abstractmethod def list_prompt_adapters(self) -> Set[int]: raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7cb8e71e48fd4..58cca657a8b4a 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -119,6 +119,10 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: "prompt_adapter_id must be greater than 0." return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." 
+ return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + def list_prompt_adapters(self) -> Set[int]: return self.driver_worker.list_prompt_adapters() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index ef4861531e7f2..0a167af3ff487 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import os import re from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union, Any import safetensors.torch import torch @@ -12,6 +12,8 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) +from vllm.adapter_commons.utils import (deactivate_adapter, add_adapter, + set_adapter_mapping, remove_adapter, list_adapters, get_adapter) from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, @@ -424,9 +426,7 @@ def __init__( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} - # self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. - # self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self @@ -502,14 +502,7 @@ def _add_adapter(self, lora: LoRAModel): self._registered_adapters[lora.id] = lora self._set_long_lora_context(lora) - def add_adapter(self, lora: LoRAModel): - logger.debug( - "Adding lora. Model id: %d, " - "int id: %d, " - "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - return super().add_adapter(lora) - - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in LoRAModelManager." @@ -685,6 +678,28 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: replacement_loras[i] = None lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras) + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. 
Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, self._add_adapter) + + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) class LoRALRUCache(AdapterLRUCache[LoRAModel]): @@ -749,7 +764,7 @@ def remove_oldest_adapter(self) -> bool: return True return False - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" self._pin_lora_in_cpu_cache(lora_id) self._pin_lora_in_gpu_cache(lora_id) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3ceda14902429..a9a2477e589ff 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -4,6 +4,10 @@ import torch from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.adapter_commons.utils import (set_active_adapters_worker, + add_adapter_worker, list_adapters_worker, + apply_adapters_worker) + from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.models import (LoRAModel, LoRAModelManager, @@ -123,8 +127,25 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) - def pin_lora(self, lora_id: int) -> bool: - return self._adapter_manager.pin_lora(lora_id) + def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, self._load_adapter, + self._adapter_manager.add_adapter, self._adapter_manager.activate_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 88b1f5fc6fd32..2058ede45330e 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import List, Optional import torch from torch import nn @@ -7,49 +7,66 @@ from vllm.adapter_commons.layers import AdapterMapping from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) - +from vllm.config import PromptAdapterConfig @dataclass class PromptAdapterMapping(AdapterMapping): pass - class VocabParallelEmbeddingWithPromptAdapter(nn.Module): def 
__init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer - self.embedding_tensors: Dict[int, torch.Tensor] = {} + self.emb_layer = self.base_layer + if 'LoRA' in base_layer.__class__.__name__: + self.emb_layer = self.base_layer.base_layer + + def create_prompt_adapter_weights(self, + prompt_adapter_config: PromptAdapterConfig): + self.embeddings_tensors = torch.zeros( + ( + prompt_adapter_config.max_prompt_adapters, + prompt_adapter_config.max_prompt_adapter_token, + self.emb_layer.embedding_dim, + ), + dtype=self.emb_layer.weight.dtype, + device=self.emb_layer.weight.device, + ) + self.adapter_lengths = torch.zeros( + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, device=self.emb_layer.weight.device) self.indices_gpu: torch.Tensor self.flag: bool = False def reset_prompt_adapter(self, index: int): - self.embedding_tensors[index] = 0 + self.embeddings_tensors[index] = 0 def set_prompt_adapter( self, index: int, - embeddings_tensor: Optional[torch.Tensor], + adapter_model: Optional[torch.Tensor], ): self.reset_prompt_adapter(index) - if embeddings_tensor is not None: - self.embedding_tensors[index] = embeddings_tensor + if adapter_model is not None: + length = adapter_model.shape[0] + self.embeddings_tensors[index, :length] = adapter_model + self.adapter_lengths[index] = length def set_mapping( self, base_indices: List[int], ): - self.indices_gpu = torch.tensor(base_indices, device="cuda") - self.flag = torch.sum(self.indices_gpu) > 0 - + self.indices_gpu = base_indices.to(device=self.emb_layer.weight.device) + self.flag = True if torch.sum(self.indices_gpu)/self.indices_gpu.shape[0]!=-1 else False + def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) if self.flag: unique_indices = torch.unique(self.indices_gpu) for idx in unique_indices: - if idx != 0: - pa_idx = self.embedding_tensors[ - idx.item()].prompt_embedding + if idx != -1: + pa_idx = self.embeddings_tensors[idx][:self.adapter_lengths[idx]] mask = (self.indices_gpu == idx) n_adapters = sum(mask) // pa_idx.shape[0] hidden_states[mask] = pa_idx.repeat(n_adapters, 1) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index e36ab687829c5..38ce448e99787 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -1,6 +1,6 @@ import logging import math -from typing import Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type import torch from peft.utils import load_peft_weights @@ -8,10 +8,15 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) +from vllm.adapter_commons.utils import (deactivate_adapter, add_adapter, + set_adapter_mapping, remove_adapter, list_adapters, get_adapter) + from vllm.config import PromptAdapterConfig from vllm.prompt_adapter.layers import ( PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) + + logger = logging.getLogger(__name__) _GLOBAL_PROMPT_ADAPTER_ID = 0 @@ -22,6 +27,26 @@ def get_prompt_adapter_id(): _GLOBAL_PROMPT_ADAPTER_ID += 1 return _GLOBAL_PROMPT_ADAPTER_ID +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], +) -> torch.Tensor: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter ids to PromptAdapter indices. 
+ + Returns: + pa_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. + """ + id_to_index = {id_: idx for idx, id_ in enumerate(prompt_adapter_index_to_id) if id_ is not None} + pa_indices = [ + id_to_index.get(id_, -1) if id_ > 0 else -1 + for id_ in mapping.index_mapping + ] + return torch.tensor(pa_indices) class PromptAdapterModel(AdapterModel): @@ -71,7 +96,7 @@ def __init__( self.model.prompt_adapter_manager = self self.adapter_type = 'PromptAdapter' - self.base_indices = [0] + self.base_indices = torch.tensor([-1]) self.modules: Dict[str, nn.Module] = {} self._create_prompt_adapter_modules() self._last_mapping: Optional[PromptAdapterMapping] = None @@ -109,7 +134,7 @@ def activate_adapter( prompt_adapter_model.id, index) self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id for _, v in self.modules.items(): - v.set_prompt_adapter(prompt_adapter_id, prompt_adapter_model) + v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) return True def _deactivate_adapter(self, prompt_adapter_id: int): @@ -117,7 +142,7 @@ def _deactivate_adapter(self, prompt_adapter_id: int): index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) self.prompt_adapter_index_to_id[index] = None for _, v in self.modules.items(): - v.reset_prompt_adapter(prompt_adapter_id) + v.reset_prompt_adapter(index) except ValueError: pass @@ -125,14 +150,17 @@ def _add_adapter(self, prompt_adapter: PromptAdapterModel): self._registered_adapters[prompt_adapter.id] = prompt_adapter def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + base_indices = convert_mapping(mapping, + self.prompt_adapter_index_to_id) for k, v in self.modules.items(): - v.set_mapping(mapping.index_mapping) + v.set_mapping(base_indices) def _create_prompt_adapter_modules(self): for module_name, module in self.model.named_modules( remove_duplicate=False): if "VocabParallel" in module.__class__.__name__: new_module = VocabParallelEmbeddingWithPromptAdapter(module) + new_module.create_prompt_adapter_weights(self.prompt_adapter_config) replaced_module = self.replace_submodule( self.model, module_name, new_module) self.register_module(module.__class__.__name__, @@ -151,11 +179,35 @@ def replace_submodule(self, model: nn.Module, module_name: str, def register_module(self, module_name: str, module: nn.Module): self.modules[module_name] = module + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in PromptAdapterModelManager." 
+ "Use LRUCachePromptAdapterModelManager for pinning") # type: ignore + def remove_all_adapters(self): """Remove all PromptAdapterModel from the manager.""" self._registered_adapters.clear() self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots self._active_adapters.clear() + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, self._deactivate_adapter) + + def add_adapter(self, adapter: PromptAdapterModel) -> bool: + return add_adapter(adapter, self._registered_adapters, self.capacity, self._add_adapter) + + def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): @@ -216,6 +268,25 @@ def remove_oldest_adapter(self) -> bool: return True return False + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) + self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) + return True + + def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): + try: + self._registered_adapters.pin(prompt_adapter_id) + except ValueError as err: + raise ValueError("Pinning failed. " + f"Prompt Adapter {prompt_adapter_id} is not registered.") from err + + def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): + if lora_id not in self._active_adapters: + # move lora to gpu if not already active + self.activate_adapter(prompt_adapter_id) + self._active_adapters.pin(prompt_adapter_id) + def create_prompt_adapter_manager( model: nn.Module, diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index 7bfdeec15b3e8..835d8039befb9 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -1,9 +1,13 @@ import logging -from typing import Any, Set, Type +from typing import Any, Set, Type, Optional import torch from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.adapter_commons.utils import (set_active_adapters_worker, + add_adapter_worker, list_adapters_worker, + apply_adapters_worker) + from vllm.config import PromptAdapterConfig from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, PromptAdapterModel, @@ -78,6 +82,26 @@ def add_dummy_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: return True + def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, self._adapter_manager.set_adapter_mapping) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, self._load_adapter, + self._adapter_manager.add_adapter, self._adapter_manager.activate_adapter) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def 
remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): """WorkerPromptAdapterManager that manages diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index eeeefd02f64b0..605f9c07f9c00 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -339,7 +339,7 @@ def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) + return self.model_runner.pin_adapter(lora_id) def list_loras(self) -> Set[int]: return self.model_runner.list_loras() @@ -351,6 +351,9 @@ def add_prompt_adapter( def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: return self.model_runner.remove_lora(prompt_adapter_id) + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.pin_adapter(prompt_adapter_id) + def list_prompt_adapters(self) -> Set[int]: return self.model_runner.list_prompt_adapters() From 2faec6129ae9ba8e92809fcde0cb82ee4204d99a Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 26 Jun 2024 00:17:44 -0400 Subject: [PATCH 49/80] formatting --- tests/lora/test_worker.py | 2 +- tests/prompt_adapter/test_pa_lora.py | 32 ++++++-------- vllm/adapter_commons/models.py | 6 +-- vllm/adapter_commons/request.py | 8 ++-- vllm/adapter_commons/utils.py | 47 +++++++++++++++------ vllm/adapter_commons/worker_manager.py | 4 +- vllm/config.py | 9 ++-- vllm/engine/arg_utils.py | 32 +++++++------- vllm/executor/cpu_executor.py | 2 +- vllm/executor/gpu_executor.py | 5 ++- vllm/lora/models.py | 27 ++++++------ vllm/lora/worker_manager.py | 25 ++++++----- vllm/prompt_adapter/layers.py | 27 +++++++----- vllm/prompt_adapter/models.py | 58 ++++++++++++++++---------- vllm/prompt_adapter/worker_manager.py | 29 +++++++------ 15 files changed, 181 insertions(+), 132 deletions(-) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index ef6a4912b68fd..732e91a52c0a9 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -63,7 +63,7 @@ def test_worker_apply_lora(sql_lora_files): random.shuffle(iter_lora_requests) iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] worker.model_runner.set_active_loras(iter_lora_requests, - LoRAMapping([], [])) + LoRAMapping([], [])) assert worker.list_loras().issuperset( {lora_request.lora_int_id for lora_request in iter_lora_requests}) diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py index 4eab29ddced42..89f349fec6337 100644 --- a/tests/prompt_adapter/test_pa_lora.py +++ b/tests/prompt_adapter/test_pa_lora.py @@ -10,27 +10,20 @@ def do_sample(engine): - + prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501 - + # first prompt with a prompt adapter and second without adapter prompts = [ - ( - prompt_text, - SamplingParams(temperature=0.0, - max_tokens=100, - stop=["[/assistant]"]), - PromptAdapterRequest("hate_speech", 1, pa_path, 8), - LoRARequest("sql_test", 1, lora_path) - ), - ( - prompt_text, - SamplingParams(temperature=0.0, 
- max_tokens=100, - stop=["[/assistant]"]), - None, - LoRARequest("sql_test", 1, lora_path) - ), + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), + PromptAdapterRequest("hate_speech", 1, pa_path, + 8), LoRARequest("sql_test", 1, lora_path)), + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), None, + LoRARequest("sql_test", 1, lora_path)), ] request_id = 0 @@ -60,9 +53,8 @@ def test_lora_prompt_adapter(): max_num_seqs=60) engine = LLMEngine.from_engine_args(engine_args) result = do_sample(engine) - + expected_output = { - " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501, " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501 } assert result == expected_output diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index 0f9c6867e88a4..6939b1405f3e1 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -66,11 +66,11 @@ def adapter_slots(self): @abstractmethod def capacity(self): ... - + @abstractmethod def activate_adapter(self, adapter_id: int) -> bool: ... - + @abstractmethod def deactivate_adapter(self, adapter_id: int) -> bool: ... @@ -101,4 +101,4 @@ def list_adapters(self) -> Dict[int, Any]: @abstractmethod def pin_adapter(self, adapter_id: int) -> bool: - ... \ No newline at end of file + ... diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py index 31b32d766dba9..69775ab7d4548 100644 --- a/vllm/adapter_commons/request.py +++ b/vllm/adapter_commons/request.py @@ -1,16 +1,18 @@ -from dataclasses import dataclass from abc import abstractmethod +from dataclasses import dataclass + @dataclass class AdapterRequest: """ Base class for adapter requests. """ - @property + + @property @abstractmethod def adapter_id(self): ... 
- + def __post_init__(self): if self.adapter_id < 1: raise ValueError(f"id must be > 0, got {self.adapter_id}") diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index 6998fe47b2f5e..6c5411f7d3d5c 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -1,44 +1,61 @@ -from typing import Dict, Optional, Any, Callable, Set +from typing import Any, Callable, Dict, Optional, Set + ## model functions -def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], deactivate_func: Callable) -> bool: +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: if adapter_id in active_adapters: deactivate_func(adapter_id) active_adapters.pop(adapter_id) return True return False -def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], capacity: int, add_func: Callable) -> bool: + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: if adapter.id not in registered_adapters: if len(registered_adapters) >= capacity: - raise RuntimeError(f'No free adapter slots.') + raise RuntimeError('No free adapter slots.') add_func(adapter) registered_adapters[adapter.id] = adapter return True return False -def set_adapter_mapping(mapping: Any, last_mapping: Any, set_mapping_func: Callable) -> Any: + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: if last_mapping != mapping: set_mapping_func(mapping) return mapping return last_mapping -def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], deactivate_func: Callable) -> bool: + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: deactivate_func(adapter_id) return bool(registered_adapters.pop(adapter_id, None)) + def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: return dict(registered_adapters) - -def get_adapter(adapter_id: int, registered_adapters: Dict[int, Any]) -> Optional[Any]: + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: return registered_adapters.get(adapter_id, None) + ## worker functions -def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], apply_adapters_func, set_adapter_mapping_func) -> None: +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: apply_adapters_func(requests) set_adapter_mapping_func(mapping) -def add_adapter_worker(adapter_request: Any, list_adapters_func, load_adapter_func, add_adapter_func, activate_adapter_func) -> bool: + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: if adapter_request.adapter_id in list_adapters_func(): return False loaded_adapter = load_adapter_func(adapter_request) @@ -46,7 +63,10 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func, load_adapter_fu activate_adapter_func(loaded_adapter.id) return loaded -def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, adapter_slots: int, remove_adapter_func, add_adapter_func) -> None: + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: models_that_exist = list_adapters_func() models_map = { adapter_request.adapter_id: adapter_request @@ -55,7 +75,7 @@ def 
apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, adapte if len(models_map) > adapter_slots: raise RuntimeError( f"Number of requested models ({len(models_map)}) is greater " - "than the number of GPU model slots " + f"than the number of GPU model slots " f"({adapter_slots}).") new_models = set(models_map) models_to_add = new_models - models_that_exist @@ -65,5 +85,6 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, adapte for adapter_id in models_to_add: add_adapter_func(models_map[adapter_id]) + def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: - return set(adapter_manager_list_adapters_func()) \ No newline at end of file + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index ae975343b1ec8..acf18993af6d7 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -18,7 +18,7 @@ def is_enabled(self) -> bool: def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: ... - + @abstractmethod def add_adapter(self, adapter_request: Any) -> bool: ... @@ -33,4 +33,4 @@ def remove_all_adapters(self): @abstractmethod def list_adapters(self) -> Set[int]: - ... \ No newline at end of file + ... diff --git a/vllm/config.py b/vllm/config.py index 78d9088c2e9dd..ba74a39c5d2e2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1138,13 +1138,16 @@ class PromptAdapterConfig: max_prompt_adapter_token: int = 10 max_cpu_prompt_adapters: Optional[int] = None prompt_adapter_dtype: Optional[torch.dtype] = None - + def __post_init__(self): library_name = 'peft' try: __import__(library_name) - except ImportError: - raise ImportError(f"'{library_name}' is not installed for prompt adapter support. Please install it using 'pip install {library_name}'.") + except ImportError as e: + raise ImportError( + f"'{library_name}' is not installed for prompt adapter support." + f"Please install it using 'pip install {library_name}'." + ) from e if self.max_prompt_adapters < 1: raise ValueError(f"max_prompt_adapters " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3b3eb0573d4c0..b37e42d6e827a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,9 +7,10 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig, ObservabilityConfig) + VisionLanguageConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser, str_to_int_tuple @@ -796,21 +797,18 @@ def create_engine_config(self, ) -> EngineConfig: "Chunked prefill is not supported with sliding window. 
" "Set --disable-sliding-window to disable sliding window.") - return EngineConfig( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - vision_language_config=vision_language_config, - speculative_config=speculative_config, - load_config=load_config, - decoding_config=decoding_config, - observability_config=observability_config, - prompt_adapter_config=prompt_adapter_config - ) - + return EngineConfig(model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config) @dataclass diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 953151e8adde6..0b2507480be9f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -104,7 +104,7 @@ def list_prompt_adapters(self) -> Set[int]: def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) - + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 58cca657a8b4a..21c37c972a173 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -120,9 +120,10 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." 
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) - + def list_prompt_adapters(self) -> Set[int]: return self.driver_worker.list_prompt_adapters() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 0a167af3ff487..2516b10d4f857 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import os import re from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple, Type, Union, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch @@ -12,8 +12,9 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) -from vllm.adapter_commons.utils import (deactivate_adapter, add_adapter, - set_adapter_mapping, remove_adapter, list_adapters, get_adapter) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, @@ -479,9 +480,6 @@ def _deactivate_adapter(self, lora_id: int): except ValueError: pass - def deactivate_adapter(self, lora_id: int) -> bool: - return super().deactivate_adapter(lora_id) - def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: return @@ -678,22 +676,27 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: replacement_loras[i] = None lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras) - + def deactivate_adapter(self, adapter_id: int) -> bool: - return deactivate_adapter(adapter_id, self._active_adapters, self._deactivate_adapter) + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) def add_adapter(self, adapter: LoRAModel) -> bool: logger.debug( "Adding lora. 
Model id: %d, " "int id: %d, " - "scaling factor: %s", adapter.id, adapter.id, adapter.scaling_factor) - return add_adapter(adapter, self._registered_adapters, self.capacity, self._add_adapter) + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) def set_adapter_mapping(self, mapping: LoRAMapping) -> None: - self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, self._set_adapter_mapping) + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) def remove_adapter(self, adapter_id: int) -> bool: - return remove_adapter(adapter_id, self._registered_adapters, self.deactivate_adapter) + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) def list_adapters(self) -> Dict[int, Any]: return list_adapters(self._registered_adapters) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index a9a2477e589ff..829abeec2dea3 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -3,11 +3,11 @@ import torch +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.adapter_commons.utils import (set_active_adapters_worker, - add_adapter_worker, list_adapters_worker, - apply_adapters_worker) - from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.models import (LoRAModel, LoRAModelManager, @@ -127,16 +127,21 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) - def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: - set_active_adapters_worker(requests, mapping, self._apply_adapters, self._adapter_manager.set_adapter_mapping) + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) def _apply_adapters(self, adapter_requests: Set[Any]) -> None: - apply_adapters_worker(adapter_requests, self.list_adapters, self._adapter_manager.adapter_slots, - self.remove_adapter, self.add_adapter) + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) def add_adapter(self, adapter_request: Any) -> bool: - return add_adapter_worker(adapter_request, self.list_adapters, self._load_adapter, - self._adapter_manager.add_adapter, self._adapter_manager.activate_adapter) + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) def remove_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.remove_adapter(adapter_id) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 2058ede45330e..0bfc4573c9c36 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -1,18 +1,20 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Optional import torch from torch import nn from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import PromptAdapterConfig from vllm.model_executor.layers.vocab_parallel_embedding import 
( VocabParallelEmbedding) -from vllm.config import PromptAdapterConfig + @dataclass class PromptAdapterMapping(AdapterMapping): pass + class VocabParallelEmbeddingWithPromptAdapter(nn.Module): def __init__(self, base_layer: VocabParallelEmbedding) -> None: @@ -21,9 +23,9 @@ def __init__(self, base_layer: VocabParallelEmbedding) -> None: self.emb_layer = self.base_layer if 'LoRA' in base_layer.__class__.__name__: self.emb_layer = self.base_layer.base_layer - - def create_prompt_adapter_weights(self, - prompt_adapter_config: PromptAdapterConfig): + + def create_prompt_adapter_weights( + self, prompt_adapter_config: PromptAdapterConfig): self.embeddings_tensors = torch.zeros( ( prompt_adapter_config.max_prompt_adapters, @@ -34,8 +36,9 @@ def create_prompt_adapter_weights(self, device=self.emb_layer.weight.device, ) self.adapter_lengths = torch.zeros( - prompt_adapter_config.max_prompt_adapters, - dtype=torch.long, device=self.emb_layer.weight.device) + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, + device=self.emb_layer.weight.device) self.indices_gpu: torch.Tensor self.flag: bool = False @@ -55,18 +58,20 @@ def set_prompt_adapter( def set_mapping( self, - base_indices: List[int], + base_indices: torch.Tensor, ): self.indices_gpu = base_indices.to(device=self.emb_layer.weight.device) - self.flag = True if torch.sum(self.indices_gpu)/self.indices_gpu.shape[0]!=-1 else False - + self.flag = torch.sum( + self.indices_gpu) / self.indices_gpu.shape[0] != -1 + def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) if self.flag: unique_indices = torch.unique(self.indices_gpu) for idx in unique_indices: if idx != -1: - pa_idx = self.embeddings_tensors[idx][:self.adapter_lengths[idx]] + pa_idx = self.embeddings_tensors[idx][:self. + adapter_lengths[idx]] mask = (self.indices_gpu == idx) n_adapters = sum(mask) // pa_idx.shape[0] hidden_states[mask] = pa_idx.repeat(n_adapters, 1) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 38ce448e99787..c5785a37f8aa9 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -8,15 +8,13 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, AdapterModelManager) -from vllm.adapter_commons.utils import (deactivate_adapter, add_adapter, - set_adapter_mapping, remove_adapter, list_adapters, get_adapter) - +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) from vllm.config import PromptAdapterConfig from vllm.prompt_adapter.layers import ( PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) - - logger = logging.getLogger(__name__) _GLOBAL_PROMPT_ADAPTER_ID = 0 @@ -27,6 +25,7 @@ def get_prompt_adapter_id(): _GLOBAL_PROMPT_ADAPTER_ID += 1 return _GLOBAL_PROMPT_ADAPTER_ID + def convert_mapping( mapping: PromptAdapterMapping, prompt_adapter_index_to_id: List[Optional[int]], @@ -34,20 +33,27 @@ def convert_mapping( """Converts PromptAdapterMapping to index tensors. Args: - mapping: PromptAdapterMapping mapping rows in a batch to PromptAdapter ids. - prompt_adapter_index_to_id: List mapping PromptAdapter ids to PromptAdapter indices. + mapping: PromptAdapterMapping mapping rows in a + batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter + ids to PromptAdapter indices. Returns: pa_indices: Tensor of shape [batch_size] mapping batch rows to PromptAdapter indices. 
""" - id_to_index = {id_: idx for idx, id_ in enumerate(prompt_adapter_index_to_id) if id_ is not None} + id_to_index = { + id_: idx + for idx, id_ in enumerate(prompt_adapter_index_to_id) + if id_ is not None + } pa_indices = [ id_to_index.get(id_, -1) if id_ > 0 else -1 for id_ in mapping.index_mapping ] return torch.tensor(pa_indices) + class PromptAdapterModel(AdapterModel): def __init__(self, @@ -150,7 +156,7 @@ def _add_adapter(self, prompt_adapter: PromptAdapterModel): self._registered_adapters[prompt_adapter.id] = prompt_adapter def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - base_indices = convert_mapping(mapping, + base_indices = convert_mapping(mapping, self.prompt_adapter_index_to_id) for k, v in self.modules.items(): v.set_mapping(base_indices) @@ -160,7 +166,8 @@ def _create_prompt_adapter_modules(self): remove_duplicate=False): if "VocabParallel" in module.__class__.__name__: new_module = VocabParallelEmbeddingWithPromptAdapter(module) - new_module.create_prompt_adapter_weights(self.prompt_adapter_config) + new_module.create_prompt_adapter_weights( + self.prompt_adapter_config) replaced_module = self.replace_submodule( self.model, module_name, new_module) self.register_module(module.__class__.__name__, @@ -183,25 +190,30 @@ def pin_adapter(self, prompt_adapter_id: int) -> bool: """Pin a PromptAdapterModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in PromptAdapterModelManager." - "Use LRUCachePromptAdapterModelManager for pinning") # type: ignore - + "Use LRUCachePromptAdapterModelManager for pinning" + ) # type: ignore + def remove_all_adapters(self): """Remove all PromptAdapterModel from the manager.""" self._registered_adapters.clear() self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots self._active_adapters.clear() - + def deactivate_adapter(self, adapter_id: int) -> bool: - return deactivate_adapter(adapter_id, self._active_adapters, self._deactivate_adapter) + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) def add_adapter(self, adapter: PromptAdapterModel) -> bool: - return add_adapter(adapter, self._registered_adapters, self.capacity, self._add_adapter) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, self._set_adapter_mapping) + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) def remove_adapter(self, adapter_id: int) -> bool: - return remove_adapter(adapter_id, self._registered_adapters, self.deactivate_adapter) + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) def list_adapters(self) -> Dict[int, Any]: return list_adapters(self._registered_adapters) @@ -213,7 +225,7 @@ def get_adapter(self, adapter_id: int) -> Optional[Any]: class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): def __init__(self, capacity: int, - deactivate_prompt_adapter_fn: Callable[[int], None]): + deactivate_prompt_adapter_fn: Callable[[int], bool]): super().__init__(capacity, deactivate_prompt_adapter_fn) @@ -278,12 +290,14 @@ def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): try: self._registered_adapters.pin(prompt_adapter_id) except ValueError as err: - raise ValueError("Pinning failed. 
" - f"Prompt Adapter {prompt_adapter_id} is not registered.") from err + raise ValueError( + "Pinning failed. " + f"Prompt Adapter {prompt_adapter_id} is not registered." + ) from err def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): - if lora_id not in self._active_adapters: - # move lora to gpu if not already active + if prompt_adapter_id not in self._active_adapters: + # move adapter to gpu if not already active self.activate_adapter(prompt_adapter_id) self._active_adapters.pin(prompt_adapter_id) diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index 835d8039befb9..49e2be00e3f41 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -1,13 +1,13 @@ import logging -from typing import Any, Set, Type, Optional +from typing import Any, Optional, Set, Type import torch +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.adapter_commons.utils import (set_active_adapters_worker, - add_adapter_worker, list_adapters_worker, - apply_adapters_worker) - from vllm.config import PromptAdapterConfig from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, PromptAdapterModel, @@ -82,16 +82,21 @@ def add_dummy_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: return True - def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: - set_active_adapters_worker(requests, mapping, self._apply_adapters, self._adapter_manager.set_adapter_mapping) + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) def add_adapter(self, adapter_request: Any) -> bool: - return add_adapter_worker(adapter_request, self.list_adapters, self._load_adapter, - self._adapter_manager.add_adapter, self._adapter_manager.activate_adapter) + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) def _apply_adapters(self, adapter_requests: Set[Any]) -> None: - apply_adapters_worker(adapter_requests, self.list_adapters, self._adapter_manager.adapter_slots, - self.remove_adapter, self.add_adapter) + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) def remove_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.remove_adapter(adapter_id) @@ -159,7 +164,7 @@ def add_adapter(self, # If the prompt_adapter is already loaded, just touch it to # update its position in the caches loaded = self._adapter_manager.get_adapter( - prompt_adapter_request.prompt_adapter_id) + prompt_adapter_request.prompt_adapter_id) is not None self._adapter_manager.activate_adapter( prompt_adapter_request.prompt_adapter_id) return loaded From 6955301a0a745d78333f080c39395f9c7b449c50 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 26 Jun 2024 00:50:50 -0400 Subject: [PATCH 50/80] rebase --- vllm/lora/worker_manager.py | 3 +++ vllm/prompt_adapter/worker_manager.py | 3 +++ vllm/worker/model_runner.py | 26 ++++++++++++++++---------- vllm/worker/worker.py | 7 +++---- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/vllm/lora/worker_manager.py 
b/vllm/lora/worker_manager.py index 829abeec2dea3..3d0ef4252b024 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -127,6 +127,9 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: set_active_adapters_worker(requests, mapping, self._apply_adapters, diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index 49e2be00e3f41..ab72e2ba83163 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -82,6 +82,9 @@ def add_dummy_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: return True + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None: set_active_adapters_worker(requests, mapping, self._apply_adapters, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e97545129a4b3..b58c3beb00877 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,7 +14,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -56,6 +55,7 @@ TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") + @dataclasses.dataclass(frozen=True) class ModelInputForGPU(ModelRunnerInputBase): """ @@ -73,7 +73,7 @@ class ModelInputForGPU(ModelRunnerInputBase): attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None prompt_adapter_mapping: Optional[PromptAdapterMapping] = None - prompt_adapter_requests: Set[PromptAdapterRequest] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -82,7 +82,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, - "prompt_adapter_mapping" : self.prompt_adapter_mapping, + "prompt_adapter_mapping": self.prompt_adapter_mapping, "prompt_adapter_requests": self.prompt_adapter_requests, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -117,7 +117,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, - "prompt_adapter_mapping" : self.prompt_adapter_mapping, + "prompt_adapter_mapping": self.prompt_adapter_mapping, "prompt_adapter_requests": self.prompt_adapter_requests, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -265,7 +265,7 @@ def load_model(self) -> None: self.prompt_adapter_config) self.model = ( self.prompt_adapter_manager.create_prompt_adapter_manager( - self.model)) + self.model)) if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors @@ -881,7 +881,7 @@ def 
list_loras(self) -> Set[int]: def remove_all_prompt_adapters(self): if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.remove_all_prompt_adapters() + self.prompt_adapter_manager.remove_all_adapters() def set_active_prompt_adapters( self, prompt_adapter_requests: Set[PromptAdapterRequest], @@ -902,6 +902,11 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: raise RuntimeError("PromptAdapter is not enabled.") return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + def list_prompt_adapters(self) -> Set[int]: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") @@ -1063,13 +1068,14 @@ def execute_model( assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) - + if self.prompt_adapter_config: assert model_input.prompt_adapter_requests is not None assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters(model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) - + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 62f1f2fe1ad63..a6948b403579b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,8 +10,7 @@ ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.lora.request import LoRARequest @@ -285,7 +284,7 @@ def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_adapter(lora_id) + return self.model_runner.pin_lora(lora_id) def list_loras(self) -> Set[int]: return self.model_runner.list_loras() @@ -298,7 +297,7 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: return self.model_runner.remove_lora(prompt_adapter_id) def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.pin_adapter(prompt_adapter_id) + return self.model_runner.pin_prompt_adapter(prompt_adapter_id) def list_prompt_adapters(self) -> Set[int]: return self.model_runner.list_prompt_adapters() From 2814aee3dd03b70f191bf0a5d2cd752fbde88151 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 26 Jun 2024 01:25:10 -0400 Subject: [PATCH 51/80] peft fix --- vllm/prompt_adapter/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index c5785a37f8aa9..b6171990138f8 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Type import torch -from peft.utils import load_peft_weights from torch import nn from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, 
@@ -71,6 +70,8 @@ def from_local_checkpoint( prompt_adapter_id: int, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> "PromptAdapterModel": + from peft.utils import load_peft_weights + adapters_weights = load_peft_weights(adapter_model_path, device) prompt_embedding = adapters_weights["prompt_embeddings"].to(dtype) num_virtual_tokens = prompt_embedding.shape[0] From 0e45660ceda8591e7bbc968eab09c5b0b7df102b Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 26 Jun 2024 02:01:15 -0400 Subject: [PATCH 52/80] minor fix --- vllm/prompt_adapter/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 0bfc4573c9c36..73f9951007127 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -61,8 +61,8 @@ def set_mapping( base_indices: torch.Tensor, ): self.indices_gpu = base_indices.to(device=self.emb_layer.weight.device) - self.flag = torch.sum( - self.indices_gpu) / self.indices_gpu.shape[0] != -1 + self.flag = (torch.sum( + self.indices_gpu) / self.indices_gpu.shape[0] != -1).item() def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) From d58e3553bc8e5603988d57b4c34c1097b9876c07 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 26 Jun 2024 02:03:14 -0400 Subject: [PATCH 53/80] formatting --- vllm/prompt_adapter/layers.py | 4 ++-- vllm/worker/cpu_model_runner.py | 1 - vllm/worker/cpu_worker.py | 3 +-- vllm/worker/embedding_model_runner.py | 3 --- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 73f9951007127..7df0893c89669 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -61,8 +61,8 @@ def set_mapping( base_indices: torch.Tensor, ): self.indices_gpu = base_indices.to(device=self.emb_layer.weight.device) - self.flag = (torch.sum( - self.indices_gpu) / self.indices_gpu.shape[0] != -1).item() + self.flag = (torch.sum(self.indices_gpu) / self.indices_gpu.shape[0] != + -1).item() def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index fe7431edb48fe..a4411c5b78347 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -9,7 +9,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 70230b35c7922..df3175797efe7 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -8,8 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index f1b1a8b6bc2d4..a3242977b9f87 100644 --- a/vllm/worker/embedding_model_runner.py +++ 
b/vllm/worker/embedding_model_runner.py @@ -6,12 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.layers import PromptAdapterMapping -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU From d7003241164ed6e65ab8d4402bf11a7cdd7457e1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 30 Jun 2024 19:26:33 -0400 Subject: [PATCH 54/80] forward update --- vllm/engine/async_llm_engine.py | 2 ++ vllm/engine/llm_engine.py | 2 ++ vllm/prompt_adapter/layers.py | 30 +++++++++++++-------------- vllm/prompt_adapter/models.py | 29 ++++++++++++++++++++------ vllm/worker/embedding_model_runner.py | 5 +++-- vllm/worker/model_runner.py | 7 +++++++ 6 files changed, 52 insertions(+), 23 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 58b16542e0323..d88cf7ea21add 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -648,6 +648,8 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. Yields: The output `RequestOutput` objects from the LLMEngine diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6ebd9937d1d04..99332c03acbe7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -273,6 +273,8 @@ def __init__( # Feature flags "enable_lora": bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, "enforce_eager": diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 7df0893c89669..52f32689a544b 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -39,8 +39,9 @@ def create_prompt_adapter_weights( prompt_adapter_config.max_prompt_adapters, dtype=torch.long, device=self.emb_layer.weight.device) + self.indices_gpu: torch.Tensor - self.flag: bool = False + self.embedding_indices_gpu: torch.Tensor def reset_prompt_adapter(self, index: int): self.embeddings_tensors[index] = 0 @@ -58,21 +59,20 @@ def set_prompt_adapter( def set_mapping( self, - base_indices: torch.Tensor, + prompt_indices: torch.Tensor, + prompt_embedding_indices: torch.Tensor, ): - self.indices_gpu = base_indices.to(device=self.emb_layer.weight.device) - self.flag = (torch.sum(self.indices_gpu) / self.indices_gpu.shape[0] != - -1).item() + self.indices_gpu = prompt_indices.to(device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to(device=self.emb_layer.weight.device) + def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) - if self.flag: - unique_indices = torch.unique(self.indices_gpu) - for idx in unique_indices: - if idx != -1: - pa_idx = self.embeddings_tensors[idx][:self. 
- adapter_lengths[idx]] - mask = (self.indices_gpu == idx) - n_adapters = sum(mask) // pa_idx.shape[0] - hidden_states[mask] = pa_idx.repeat(n_adapters, 1) - return hidden_states + if self.embedding_indices_gpu.numel(): + valid_mask = self.indices_gpu != -1 + gathered_embeddings = self.embeddings_tensors[self.embedding_indices_gpu[:,0], + self.embedding_indices_gpu[:,1]] + + # Update hidden states + hidden_states[valid_mask] = gathered_embeddings + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index b6171990138f8..b9a8dcc695eff 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -24,6 +24,18 @@ def get_prompt_adapter_id(): _GLOBAL_PROMPT_ADAPTER_ID += 1 return _GLOBAL_PROMPT_ADAPTER_ID +def convert_to_embedding_indices(indices): + embedding_indices = [] + count = 0 + + for value in indices: + if value == -1: + count = 0 + else: + embedding_indices.append([value, count]) + count += 1 + + return torch.tensor(embedding_indices) def convert_mapping( mapping: PromptAdapterMapping, @@ -46,11 +58,13 @@ def convert_mapping( for idx, id_ in enumerate(prompt_adapter_index_to_id) if id_ is not None } - pa_indices = [ + pa_indices = torch.tensor([ id_to_index.get(id_, -1) if id_ > 0 else -1 for id_ in mapping.index_mapping - ] - return torch.tensor(pa_indices) + ]) + + pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + return pa_indices, pa_embedding_mapping class PromptAdapterModel(AdapterModel): @@ -104,6 +118,8 @@ def __init__( self.adapter_type = 'PromptAdapter' self.base_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([-1]) + self.modules: Dict[str, nn.Module] = {} self._create_prompt_adapter_modules() self._last_mapping: Optional[PromptAdapterMapping] = None @@ -157,10 +173,10 @@ def _add_adapter(self, prompt_adapter: PromptAdapterModel): self._registered_adapters[prompt_adapter.id] = prompt_adapter def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - base_indices = convert_mapping(mapping, + base_indices, base_embedding_indices = convert_mapping(mapping, self.prompt_adapter_index_to_id) for k, v in self.modules.items(): - v.set_mapping(base_indices) + v.set_mapping(base_indices, base_embedding_indices) def _create_prompt_adapter_modules(self): for module_name, module in self.model.named_modules( @@ -173,7 +189,8 @@ def _create_prompt_adapter_modules(self): self.model, module_name, new_module) self.register_module(module.__class__.__name__, replaced_module) - replaced_module.set_mapping(self.base_indices) + replaced_module.set_mapping(self.base_indices, + self.base_embedding_indices) break def replace_submodule(self, model: nn.Module, module_name: str, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a3242977b9f87..3010bae15e01f 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -69,8 +69,9 @@ def execute_model( if self.prompt_adapter_config: assert model_input.prompt_adapter_requests is not None assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters(model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) # Currently cuda graph is only supported by the decode phase. 
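The forward rewrite above trades the per-adapter Python loop for one masked gather, keyed by the [adapter_index, offset] pairs that convert_to_embedding_indices produces. The small self-contained illustration below exercises just that indexing; the function body is copied from the patch, while every shape and value is invented for the example.

import torch


def convert_to_embedding_indices(indices):
    # Copied from vllm/prompt_adapter/models.py in this series: every token
    # position that belongs to an adapter becomes [adapter_index, offset].
    embedding_indices = []
    count = 0
    for value in indices:
        if value == -1:
            count = 0
        else:
            embedding_indices.append([value, count])
            count += 1
    return torch.tensor(embedding_indices)


# Toy batch: 4 virtual soft-prompt tokens for adapter index 0, then 3
# ordinary tokens (-1 means "no prompt adapter at this position").
index_mapping = [0, 0, 0, 0, -1, -1, -1]
indices = torch.tensor(index_mapping)
embedding_indices = convert_to_embedding_indices(index_mapping)
print(embedding_indices.tolist())   # [[0, 0], [0, 1], [0, 2], [0, 3]]

# embeddings_tensors is [max_prompt_adapters, max_prompt_adapter_token, hidden].
embeddings_tensors = torch.arange(2 * 8 * 4, dtype=torch.float32).view(2, 8, 4)

# Stand-in for the base VocabParallelEmbedding output of the 7 tokens.
hidden_states = torch.zeros(7, 4)

# Same substitution as VocabParallelEmbeddingWithPromptAdapter.forward():
valid_mask = indices != -1
gathered = embeddings_tensors[embedding_indices[:, 0], embedding_indices[:, 1]]
hidden_states[valid_mask] = gathered
print(hidden_states[:4])   # rows 0-3 now hold adapter 0's soft-prompt rows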
assert model_input.attn_metadata is not None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b58c3beb00877..701f6795a64bf 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -986,6 +986,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) + if self.prompt_adapter_config: + pa_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters(set(), pa_mapping) + graph_runner = CUDAGraphRunner(self.model) hidden_states = graph_runner.capture( input_tokens[:batch_size], From a5610a7d1c88c57fd1ea32589eb9b81d030e7466 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 30 Jun 2024 19:47:08 -0400 Subject: [PATCH 55/80] formatting --- vllm/engine/llm_engine.py | 2 +- vllm/prompt_adapter/layers.py | 18 ++++++++++-------- vllm/prompt_adapter/models.py | 14 ++++++++------ 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 99332c03acbe7..820ae90bf15bb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1030,7 +1030,7 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.model_executor.list_loras() - + def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 52f32689a544b..07aa015d82572 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -39,7 +39,7 @@ def create_prompt_adapter_weights( prompt_adapter_config.max_prompt_adapters, dtype=torch.long, device=self.emb_layer.weight.device) - + self.indices_gpu: torch.Tensor self.embedding_indices_gpu: torch.Tensor @@ -62,17 +62,19 @@ def set_mapping( prompt_indices: torch.Tensor, prompt_embedding_indices: torch.Tensor, ): - self.indices_gpu = prompt_indices.to(device=self.emb_layer.weight.device) - self.embedding_indices_gpu = prompt_embedding_indices.to(device=self.emb_layer.weight.device) - + self.indices_gpu = prompt_indices.to( + device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to( + device=self.emb_layer.weight.device) def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) if self.embedding_indices_gpu.numel(): valid_mask = self.indices_gpu != -1 - gathered_embeddings = self.embeddings_tensors[self.embedding_indices_gpu[:,0], - self.embedding_indices_gpu[:,1]] - + gathered_embeddings = self.embeddings_tensors[ + self.embedding_indices_gpu[:, 0], + self.embedding_indices_gpu[:, 1]] + # Update hidden states hidden_states[valid_mask] = gathered_embeddings - return hidden_states \ No newline at end of file + return hidden_states diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index b9a8dcc695eff..acd878dc9b9a5 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -24,19 +24,21 @@ def get_prompt_adapter_id(): _GLOBAL_PROMPT_ADAPTER_ID += 1 return _GLOBAL_PROMPT_ADAPTER_ID + def convert_to_embedding_indices(indices): embedding_indices = [] count = 0 - + for value in indices: if value == -1: count = 0 else: embedding_indices.append([value, count]) count += 1 - + return torch.tensor(embedding_indices) + def convert_mapping( mapping: PromptAdapterMapping, prompt_adapter_index_to_id: List[Optional[int]], @@ -85,7 +87,7 @@ def from_local_checkpoint( device: str = "cuda", dtype: Optional[torch.dtype] 
= None) -> "PromptAdapterModel": from peft.utils import load_peft_weights - + adapters_weights = load_peft_weights(adapter_model_path, device) prompt_embedding = adapters_weights["prompt_embeddings"].to(dtype) num_virtual_tokens = prompt_embedding.shape[0] @@ -173,8 +175,8 @@ def _add_adapter(self, prompt_adapter: PromptAdapterModel): self._registered_adapters[prompt_adapter.id] = prompt_adapter def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - base_indices, base_embedding_indices = convert_mapping(mapping, - self.prompt_adapter_index_to_id) + base_indices, base_embedding_indices = convert_mapping( + mapping, self.prompt_adapter_index_to_id) for k, v in self.modules.items(): v.set_mapping(base_indices, base_embedding_indices) @@ -189,7 +191,7 @@ def _create_prompt_adapter_modules(self): self.model, module_name, new_module) self.register_module(module.__class__.__name__, replaced_module) - replaced_module.set_mapping(self.base_indices, + replaced_module.set_mapping(self.base_indices, self.base_embedding_indices) break From 8b6e827e9677eaf5d2e19ca74fa63c475ee937bd Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 30 Jun 2024 20:42:11 -0400 Subject: [PATCH 56/80] formatting --- vllm/worker/model_runner.py | 38 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 623b99d1115fc..df5c6acd47311 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -534,22 +534,22 @@ def _prepare_model_input_tensors( mm_kwargs = self.multi_modal_input_mapper(mm_data) for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) - + if prompt_adapter_id > 0: - prompt_adapter_requests.add( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.\ - prompt_adapter_num_virtual_tokens - pm = [prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - prompt_adapter_index_mapping += pm - prompt_adapter_prompt_mapping.extend( - [prompt_adapter_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - else 1)) - + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + is_profile_run = _is_block_tables_empty( seq_group_metadata.block_tables) if is_profile_run: @@ -1072,14 +1072,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: [0] * batch_size, ) self.set_active_loras(set(), lora_mapping) - + if self.prompt_adapter_config: prompt_adapter_mapping = PromptAdapterMapping( [-1] * batch_size, [-1] * batch_size, ) - self.set_active_prompt_adapters(set(), - prompt_adapter_mapping) + self.set_active_prompt_adapters(set(), + prompt_adapter_mapping) graph_runner = CUDAGraphRunner(self.model, self.attn_backend.get_name()) @@ -1187,7 +1187,7 @@ def execute_model( self.set_active_prompt_adapters( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - + if self.attn_backend.get_name() == "flashinfer": assert model_input.attn_metadata is not None assert model_input.input_tokens is not None From b83b6f02c0be1a18d3d054e36e3c9eb88ddc2d4b Mon Sep 17 00:00:00 
2001 From: Swapnil Parekh Date: Mon, 1 Jul 2024 11:51:27 -0400 Subject: [PATCH 57/80] spec decode fix --- vllm/lora/models.py | 2 +- vllm/spec_decode/draft_model_runner.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index f03a1d31881e3..e1ede7d4d710a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import LRUCache, is_pin_memory_available +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index f30d29376121a..b65e1288b1c39 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -3,8 +3,8 @@ import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, @@ -47,6 +47,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, ): if return_hidden_states: @@ -65,6 +66,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, ) @@ -130,6 +132,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + outputs: List[SamplerOutput] = [] for step in range(num_steps): # Currently cuda graph is only supported by the decode phase. 
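Both the regular and the draft model runners end up handing the same PromptAdapterMapping to set_active_prompt_adapters. The sketch below shows what that mapping holds for a single prefill sequence and the all -1 placeholder installed while capturing CUDA graphs; it assumes a build containing this series and the same (index_mapping, prompt_mapping) field order that LoRAMapping uses, and all of the numbers are only illustrative.

from vllm.prompt_adapter.layers import PromptAdapterMapping

prompt_adapter_id = 1     # id carried by the PromptAdapterRequest
num_virtual_tokens = 8    # length of the adapter's soft prompt
query_len = 13            # prefill tokens scheduled for the sequence

# Soft-prompt positions carry the adapter id; the rest carry 0 ("no adapter"),
# mirroring the pm list built in _prepare_model_input_tensors above.
index_mapping = ([prompt_adapter_id] * num_virtual_tokens +
                 [0] * (query_len - num_virtual_tokens))
# One entry per sequence being sampled (query_len entries if prompt logprobs
# were requested).
prompt_mapping = [prompt_adapter_id]

mapping = PromptAdapterMapping(index_mapping, prompt_mapping)
print(mapping.index_mapping)

# During CUDA-graph capture there is no real request, so the runner installs a
# placeholder mapping of -1s; convert_mapping() maps every non-positive id to
# -1 and the embedding layer leaves those rows untouched.
batch_size = 4
placeholder = PromptAdapterMapping([-1] * batch_size, [-1] * batch_size)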
From 791ffbd8ad49b3f0baca9791592e5306850ed48d Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 2 Jul 2024 15:17:20 -0400 Subject: [PATCH 58/80] formatting --- vllm/engine/llm_engine.py | 1 - vllm/prompt_adapter/models.py | 5 +++-- vllm/sequence.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 240e59f76df29..04d4cec8ea0ad 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -555,7 +555,6 @@ def process_model_inputs( return self.input_processor(llm_inputs) - def add_request( self, request_id: str, diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index acd878dc9b9a5..d42d678e9c062 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -11,8 +11,9 @@ get_adapter, list_adapters, remove_adapter, set_adapter_mapping) from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import ( - PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) +from vllm.prompt_adapter.layers import (PromptAdapterMapping, + VocabParallelEmbeddingWithPromptAdapter + ) logger = logging.getLogger(__name__) diff --git a/vllm/sequence.py b/vllm/sequence.py index c9e161a60b040..6d7754ab4f3d6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -242,13 +242,13 @@ class Sequence: """ def __init__( - self, - seq_id: int, - inputs: "LLMInputs", - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + seq_id: int, + inputs: "LLMInputs", + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: self.seq_id = seq_id self.inputs = inputs From 215947d299e8d97b8c94865f285a6101f0179d98 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 2 Jul 2024 15:23:41 -0400 Subject: [PATCH 59/80] async executor --- vllm/executor/executor_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 4660f9b612dc5..3dc16fb2bb95e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -138,6 +138,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], ) -> None: # This locks each pipeline parallel stage so multiple virtual engines # can't execute on the same stage at the same time @@ -149,7 +150,7 @@ def __init__( super().__init__(model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, lora_config, vision_language_config, - speculative_config) + speculative_config, prompt_adapter_config) @abstractmethod async def execute_model_async( From 9ae47e830d29dd9d0b68a214a55a84745fa160d1 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 2 Jul 2024 15:37:39 -0400 Subject: [PATCH 60/80] formatting --- vllm/prompt_adapter/models.py | 5 ++--- vllm/worker/model_runner.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index d42d678e9c062..acd878dc9b9a5 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -11,9 +11,8 @@ get_adapter, list_adapters, remove_adapter, set_adapter_mapping) from 
vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import (PromptAdapterMapping, - VocabParallelEmbeddingWithPromptAdapter - ) +from vllm.prompt_adapter.layers import ( + PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) logger = logging.getLogger(__name__) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c93d49d592e30..d5173be50c337 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,8 +24,8 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig, PromptAdapterConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY From 3a2b545a71423ed8c502e413ffd50a3faac2c934 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 2 Jul 2024 15:45:50 -0400 Subject: [PATCH 61/80] formatting --- vllm/prompt_adapter/models.py | 5 +++-- vllm/worker/model_runner.py | 14 +++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index acd878dc9b9a5..d42d678e9c062 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -11,8 +11,9 @@ get_adapter, list_adapters, remove_adapter, set_adapter_mapping) from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import ( - PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) +from vllm.prompt_adapter.layers import (PromptAdapterMapping, + VocabParallelEmbeddingWithPromptAdapter + ) logger = logging.getLogger(__name__) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d5173be50c337..9c41d8087bf45 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1102,14 +1102,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: [0] * batch_size, ) self.set_active_loras(set(), lora_mapping) - + if self.prompt_adapter_config: - prompt_adapter_mapping = PromptAdapterMapping( - [-1] * batch_size, - [-1] * batch_size, - ) - self.set_active_prompt_adapters(set(), - prompt_adapter_mapping) + prompt_adapter_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters( + set(), prompt_adapter_mapping) graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name()) From bbaea88d902cb62b9df3f69077e4b0df088fd510 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 2 Jul 2024 18:39:14 -0400 Subject: [PATCH 62/80] formatting --- vllm/prompt_adapter/models.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index d42d678e9c062..acd878dc9b9a5 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -11,9 +11,8 @@ get_adapter, list_adapters, remove_adapter, set_adapter_mapping) from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import (PromptAdapterMapping, - VocabParallelEmbeddingWithPromptAdapter - ) +from vllm.prompt_adapter.layers import ( + PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) logger = logging.getLogger(__name__) From cdcea67c07dc93eaf11eab960383e8436f906382 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 3 Jul 2024 09:01:23 -0400 
Subject: [PATCH 63/80] formatting --- vllm/worker/model_runner.py | 6 +++--- vllm/worker/worker.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 577101061d2b6..174d5818e4f05 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -37,12 +37,12 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import supports_lora +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, + MultiModalInputs) from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, - MultiModalInputs) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -813,7 +813,7 @@ def _prepare_model_input_tensors( prompt_adapter_mapping=prompt_adapter_mapping, prompt_adapter_requests=prompt_adapter_requests, ) - + @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 43e047205ea26..1eae790174224 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -16,8 +16,8 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.platforms import current_platform +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner From e771d43d89453a013f28ad0c052d8e94008262dc Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 3 Jul 2024 10:08:30 -0400 Subject: [PATCH 64/80] max_prompt_adapter_token defaults + error messages --- vllm/config.py | 6 ++++-- vllm/engine/arg_utils.py | 2 +- vllm/prompt_adapter/models.py | 22 +++++++++++++++------- vllm/prompt_adapter/worker_manager.py | 5 ++++- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ecd7e04517836..064931dc49005 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1295,11 +1295,10 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): raise ValueError("LoRA is not supported with chunked prefill yet.") -# TODO: To be replaced by MultiModalConfig. @dataclass class PromptAdapterConfig: max_prompt_adapters: int - max_prompt_adapter_token: int = 10 + max_prompt_adapter_token: int max_cpu_prompt_adapters: Optional[int] = None prompt_adapter_dtype: Optional[torch.dtype] = None @@ -1316,6 +1315,8 @@ def __post_init__(self): if self.max_prompt_adapters < 1: raise ValueError(f"max_prompt_adapters " f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") if self.max_cpu_prompt_adapters is None: self.max_cpu_prompt_adapters = self.max_prompt_adapters @@ -1327,6 +1328,7 @@ def verify_with_model_config(self, model_config: ModelConfig): self.prompt_adapter_dtype) +# TODO: To be replaced by MultiModalConfig. 
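With the defaults change above, max_prompt_adapter_token is no longer silently 10: enabling prompt adapters without sizing it now fails in PromptAdapterConfig.__post_init__. A quick check of the new invariants, assuming peft is installed and a vLLM build that contains this series:

from vllm.config import PromptAdapterConfig

cfg = PromptAdapterConfig(max_prompt_adapters=3, max_prompt_adapter_token=8)
print(cfg.max_cpu_prompt_adapters)   # defaults to max_prompt_adapters -> 3

try:
    # EngineArgs now defaults max_prompt_adapter_token to 0, so leaving it
    # unset while enabling the feature is rejected up front.
    PromptAdapterConfig(max_prompt_adapters=1, max_prompt_adapter_token=0)
except ValueError as err:
    print(err)   # max_prompt_adapter_token must be set.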
@dataclass class VisionLanguageConfig: """Configs the input data format and how models should run for diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 07302b1602690..8534540e8565e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -69,7 +69,7 @@ class EngineArgs: max_lora_rank: int = 16 enable_prompt_adapter: bool = False max_prompt_adapters: int = 1 - max_prompt_adapter_token: int = 10 + max_prompt_adapter_token: int = 0 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index acd878dc9b9a5..d1c7a4fbb3415 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -81,16 +81,24 @@ def __init__(self, @classmethod def from_local_checkpoint( - cls, - adapter_model_path: str, - prompt_adapter_id: int, - device: str = "cuda", - dtype: Optional[torch.dtype] = None) -> "PromptAdapterModel": + cls, + adapter_model_path: str, + prompt_adapter_id: int, + num_virtual_tokens: int, + config: PromptAdapterConfig, + device: str = "cuda", + ) -> "PromptAdapterModel": from peft.utils import load_peft_weights + if num_virtual_tokens > config.max_prompt_adapter_token: + raise ValueError( + f'num_virtual_tokens ({num_virtual_tokens}) should be <= ' + f'max_prompt_adapter_token({config.max_prompt_adapter_token})') + adapters_weights = load_peft_weights(adapter_model_path, device) - prompt_embedding = adapters_weights["prompt_embeddings"].to(dtype) - num_virtual_tokens = prompt_embedding.shape[0] + prompt_embedding = adapters_weights["prompt_embeddings"].to( + config.prompt_adapter_dtype) + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py index ab72e2ba83163..ddc1ef893c6f2 100644 --- a/vllm/prompt_adapter/worker_manager.py +++ b/vllm/prompt_adapter/worker_manager.py @@ -69,8 +69,11 @@ def _load_adapter( self._prompt_adapter_model_cls.from_local_checkpoint( prompt_adapter_request.prompt_adapter_local_path, prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + num_virtual_tokens=prompt_adapter_request. 
+ prompt_adapter_num_virtual_tokens, + config=self.prompt_adapter_config, device=str(self.device), - dtype=self.prompt_adapter_config.prompt_adapter_dtype)) + )) except Exception as e: raise RuntimeError( f"Loading prompt_adapter " From 503adf4247431c5023b9d6c02052b1c5dcf042ab Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Wed, 3 Jul 2024 17:03:53 -0400 Subject: [PATCH 65/80] updating tests --- tests/prompt_adapter/test_bloom.py | 3 ++- tests/prompt_adapter/test_multi_adapter_inference.py | 3 ++- tests/prompt_adapter/test_pa_lora.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index 7c13a81b6f2cb..d0ae194d7a1b4 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -32,7 +32,8 @@ def do_sample(llm, pa_name: str, pa_id: int): def test_twitter_prompt_adapter(): - llm = vllm.LLM(MODEL_PATH, enable_prompt_adapter=True) + llm = vllm.LLM(MODEL_PATH, enable_prompt_adapter=True, + max_prompt_adapter_token= 8) expected_output = ['complaint', 'no complaint'] diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py index 0cc8c8bc50fd0..dab56918756f3 100644 --- a/tests/prompt_adapter/test_multi_adapter_inference.py +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -44,7 +44,8 @@ def do_sample(engine): def test_multi_prompt_adapters(): engine_args = EngineArgs(model=MODEL_PATH, max_prompt_adapters=3, - enable_prompt_adapter=True) + enable_prompt_adapter=True, + max_prompt_adapter_token= 8) engine = LLMEngine.from_engine_args(engine_args) expected_output = { ' quot;I', 'hate speech', 'no complaint', 'not hate speech' diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py index 89f349fec6337..650f60f2289d0 100644 --- a/tests/prompt_adapter/test_pa_lora.py +++ b/tests/prompt_adapter/test_pa_lora.py @@ -50,7 +50,8 @@ def test_lora_prompt_adapter(): engine_args = EngineArgs(model=MODEL_PATH, enable_prompt_adapter=True, enable_lora=True, - max_num_seqs=60) + max_num_seqs=60, + max_prompt_adapter_token= 8) engine = LLMEngine.from_engine_args(engine_args) result = do_sample(engine) From 45c12ee714d4a09981b3dab5f1c68beb90c6b816 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Fri, 5 Jul 2024 14:59:43 -0400 Subject: [PATCH 66/80] fix eager issue --- tests/prompt_adapter/test_bloom.py | 6 ++++-- tests/prompt_adapter/test_multi_adapter_inference.py | 2 +- tests/prompt_adapter/test_pa_lora.py | 2 +- vllm/prompt_adapter/models.py | 2 +- vllm/worker/model_runner.py | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index d0ae194d7a1b4..7d736e187aef6 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -32,8 +32,10 @@ def do_sample(llm, pa_name: str, pa_id: int): def test_twitter_prompt_adapter(): - llm = vllm.LLM(MODEL_PATH, enable_prompt_adapter=True, - max_prompt_adapter_token= 8) + llm = vllm.LLM(MODEL_PATH, + enforce_eager=True, + enable_prompt_adapter=True, + max_prompt_adapter_token=8) expected_output = ['complaint', 'no complaint'] diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py index dab56918756f3..39a79becdfbb3 100644 --- a/tests/prompt_adapter/test_multi_adapter_inference.py +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -45,7 +45,7 @@ def 
test_multi_prompt_adapters(): engine_args = EngineArgs(model=MODEL_PATH, max_prompt_adapters=3, enable_prompt_adapter=True, - max_prompt_adapter_token= 8) + max_prompt_adapter_token=8) engine = LLMEngine.from_engine_args(engine_args) expected_output = { ' quot;I', 'hate speech', 'no complaint', 'not hate speech' diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py index 650f60f2289d0..2a5f23f7f92ec 100644 --- a/tests/prompt_adapter/test_pa_lora.py +++ b/tests/prompt_adapter/test_pa_lora.py @@ -51,7 +51,7 @@ def test_lora_prompt_adapter(): enable_prompt_adapter=True, enable_lora=True, max_num_seqs=60, - max_prompt_adapter_token= 8) + max_prompt_adapter_token=8) engine = LLMEngine.from_engine_args(engine_args) result = do_sample(engine) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index d1c7a4fbb3415..6e9d74fb52b24 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -128,7 +128,7 @@ def __init__( self.adapter_type = 'PromptAdapter' self.base_indices = torch.tensor([-1]) - self.base_embedding_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([]) self.modules: Dict[str, nn.Module] = {} self._create_prompt_adapter_modules() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 174d5818e4f05..b57ca31ebcffd 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -555,7 +555,7 @@ def _prepare_model_input_tensors( mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) - if prompt_adapter_id > 0: + if prompt_adapter_id > 0 and is_prompt: prompt_adapter_requests.add( seq_group_metadata.prompt_adapter_request) From 13d42c685d84f54647886a958dc8f475f5ffdc7c Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Fri, 5 Jul 2024 15:25:47 -0400 Subject: [PATCH 67/80] formatting --- vllm/worker/embedding_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d9adb5a2e24fd..a333e6634a41f 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -5,7 +5,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams @@ -52,7 +52,7 @@ def __init__( lora_config=lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config) + prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config) @torch.inference_mode() From b2f3842f7982ee541d73cab8bbd2d81287c51118 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Fri, 5 Jul 2024 15:32:17 -0400 Subject: [PATCH 68/80] formatting --- vllm/executor/executor_base.py | 3 ++- vllm/executor/ray_xpu_executor.py | 3 ++- vllm/executor/xpu_executor.py | 3 ++- vllm/worker/worker.py | 3 ++- vllm/worker/xpu_worker.py | 3 ++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 0b4d3917abe34..6c8593c245396 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,7 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, 
MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, PromptAdapterConfig) + SchedulerConfig, SpeculativeConfig, + PromptAdapterConfig) from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index f8b7208a6b1bb..c74b0be780f93 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, PromptAdapterConfig) + SchedulerConfig, SpeculativeConfig, + PromptAdapterConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 1202b1219f7f3..ebc6b89164c79 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -4,7 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, PromptAdapterConfig) + SchedulerConfig, SpeculativeConfig, + PromptAdapterConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 571b3a9cae50a..796133dc4bbab 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, PromptAdapterConfig) + SchedulerConfig, SpeculativeConfig, + PromptAdapterConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 53f2eb6390124..a00d2b86c7f2a 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -10,7 +10,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, PromptAdapterConfig) + SchedulerConfig, SpeculativeConfig, + PromptAdapterConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger From 191f2c92570f72c6791db07ece2ab1d204c38380 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Fri, 5 Jul 2024 22:25:21 -0400 Subject: [PATCH 69/80] replacing numel w ndim for LoRA consistency --- vllm/prompt_adapter/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index 07aa015d82572..d1ca4143b9dab 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -69,7 +69,7 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) - if self.embedding_indices_gpu.numel(): + if self.embedding_indices_gpu.ndim>1: valid_mask = self.indices_gpu != -1 gathered_embeddings = self.embeddings_tensors[ self.embedding_indices_gpu[:, 0], From 50514c3a94fc4d2cbd2fd9964311b6fb6f196ab0 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 
2024 15:05:00 -0400 Subject: [PATCH 70/80] Update tests/prompt_adapter/test_bloom.py Co-authored-by: Antoni Baum --- tests/prompt_adapter/test_bloom.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index 7d736e187aef6..980a40e7e88c6 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -31,9 +31,10 @@ def do_sample(llm, pa_name: str, pa_id: int): return generated_texts -def test_twitter_prompt_adapter(): +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_twitter_prompt_adapter(enforce_eager: bool): llm = vllm.LLM(MODEL_PATH, - enforce_eager=True, + enforce_eager=enforce_eager, enable_prompt_adapter=True, max_prompt_adapter_token=8) From 1217964890e693369cf8c5f89e734d18b5ef5f6d Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 15:06:58 -0400 Subject: [PATCH 71/80] Update vllm/prompt_adapter/models.py Co-authored-by: Antoni Baum --- vllm/prompt_adapter/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 6e9d74fb52b24..e80afadec7b6b 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -60,12 +60,13 @@ def convert_mapping( for idx, id_ in enumerate(prompt_adapter_index_to_id) if id_ is not None } - pa_indices = torch.tensor([ + pa_indices = ([ id_to_index.get(id_, -1) if id_ > 0 else -1 for id_ in mapping.index_mapping ]) pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + pa_indices = torch.tensor(pa_indices) return pa_indices, pa_embedding_mapping From f9a5b4a77c71de8350f3a00e1e111f1c1e4ce9d9 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 15:08:39 -0400 Subject: [PATCH 72/80] formatting --- tests/prompt_adapter/test_bloom.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py index 980a40e7e88c6..6528b3009b8c0 100644 --- a/tests/prompt_adapter/test_bloom.py +++ b/tests/prompt_adapter/test_bloom.py @@ -1,3 +1,5 @@ +import pytest + import vllm from vllm.prompt_adapter.request import PromptAdapterRequest From 8545205ccc112515366a0f043149a9bc7481d049 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 15:10:34 -0400 Subject: [PATCH 73/80] formatting --- vllm/engine/arg_utils.py | 4 ++-- vllm/engine/llm_engine.py | 5 +++-- vllm/executor/executor_base.py | 4 ++-- vllm/executor/ray_xpu_executor.py | 4 ++-- vllm/executor/xpu_executor.py | 4 ++-- vllm/spec_decode/draft_model_runner.py | 2 +- vllm/worker/cpu_model_runner.py | 2 +- vllm/worker/cpu_worker.py | 2 +- vllm/worker/model_runner.py | 2 +- vllm/worker/worker.py | 4 ++-- vllm/worker/xpu_model_runner.py | 2 +- vllm/worker/xpu_worker.py | 4 ++-- 12 files changed, 20 insertions(+), 19 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 00f9335a58f79..b972573c0258e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,8 +7,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser diff --git 
a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e8a9510c41fc2..1d7f71ee49c19 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,8 +8,9 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, PromptAdapterConfig) + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 6c8593c245396..6f9e554459161 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,8 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index c74b0be780f93..33f9321b5ff36 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -8,8 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index ebc6b89164c79..f6550cce9ab1a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -4,8 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 635f80d24bc13..90bba96ee8acb 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -4,7 +4,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 74c9243c1e07c..db0e178e45f4e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger 
import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index f192154a772c2..3c22c73267b7f 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,7 +7,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3d9b5b17c7f3e..06b9c90cf65ce 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,7 +25,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 796133dc4bbab..857cd86beff92 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,8 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 8910e58493b59..e03f24fdfc41a 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import broadcast_tensor_dict from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index a00d2b86c7f2a..6a822c2ba3e7a 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -10,8 +10,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - PromptAdapterConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger From 2d5c24629950cb4fa0e3a46db37913ecf5ae78cf Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 17:48:14 -0400 Subject: [PATCH 74/80] formatting --- vllm/prompt_adapter/layers.py | 4 ++-- vllm/prompt_adapter/models.py | 8 +++++++- vllm/worker/model_runner.py | 17 ++++++++--------- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py index d1ca4143b9dab..27a61e692e1b7 100644 --- a/vllm/prompt_adapter/layers.py +++ b/vllm/prompt_adapter/layers.py @@ -69,7 +69,7 @@ def 
set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = self.base_layer(x) - if self.embedding_indices_gpu.ndim>1: + if self.embedding_indices_gpu.ndim > 1: valid_mask = self.indices_gpu != -1 gathered_embeddings = self.embeddings_tensors[ self.embedding_indices_gpu[:, 0], @@ -77,4 +77,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Update hidden states hidden_states[valid_mask] = gathered_embeddings - return hidden_states + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index e80afadec7b6b..93eb3bde646ac 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -12,7 +12,8 @@ remove_adapter, set_adapter_mapping) from vllm.config import PromptAdapterConfig from vllm.prompt_adapter.layers import ( - PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) + VocabParallelEmbeddingWithPromptAdapter) # yapf: disable +from vllm.prompt_adapter.layers import PromptAdapterMapping logger = logging.getLogger(__name__) @@ -117,6 +118,11 @@ def __init__( Args: model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + prompt_adapter_config: the PromptAdapter config, """ self.model: nn.Module = model # Dict instead of a Set for compatibility with LRUCache. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 06b9c90cf65ce..205b4f58f7a83 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -253,15 +253,14 @@ def __init__( def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + multimodal_config=self.multimodal_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", From 3da2777a3cbdc0d81beec734d3ef83a221a59737 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 18:12:17 -0400 Subject: [PATCH 75/80] docs update --- vllm/engine/llm_engine.py | 2 ++ vllm/entrypoints/llm.py | 4 ++++ vllm/sequence.py | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1d7f71ee49c19..b476594fc73f6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -95,6 +95,8 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. """ diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e2c02c9db6c60..57e81a6317725 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -273,6 +273,8 @@ def generate( prompts and it is paired one by one with the prompt. 
use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. Returns: A list of `RequestOutput` objects containing the @@ -415,6 +417,8 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. Returns: A list of `EmbeddingRequestOutput` objects containing the diff --git a/vllm/sequence.py b/vllm/sequence.py index f6d271e44d2d9..a3f998b94d795 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -239,6 +239,8 @@ class Sequence: block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. + prompt_adapter_request: Prompt Adapter request. + """ def __init__( @@ -422,6 +424,7 @@ class SequenceGroup: encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request. """ def __init__( @@ -644,6 +647,7 @@ class SequenceGroupMetadata: (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. + prompt_adapter_request: Prompt Adapter request. """ def __init__( From 8279496b4c2f7ef7b8723fd7ee2a2805cb1bbeba Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 21:58:03 -0400 Subject: [PATCH 76/80] formatting --- vllm/entrypoints/openai/serving_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a646f3ac97a2a..9ba5450891e8b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -75,7 +75,8 @@ def __init__( self.prompt_adapter_requests = [] if prompt_adapters is not None: for i, prompt_adapter in enumerate(prompt_adapters, start=1): - with open(f"./{prompt_adapter.local_path}/adapter_config.json") as f: + with open(f"./{prompt_adapter.local_path}" + f"/adapter_config.json") as f: adapter_config = json.load(f) num_virtual_tokens = adapter_config["num_virtual_tokens"] self.prompt_adapter_requests.append( @@ -100,6 +101,8 @@ async def show_available_models(self) -> ModelList: permission=[ModelPermission()]) for lora in self.lora_requests ] + model_cards.extend(lora_cards) + prompt_adapter_cards = [ ModelCard(id=prompt_adapter.prompt_adapter_name, root=self.served_model_names[0], @@ -107,7 +110,6 @@ async def show_available_models(self) -> ModelList: for prompt_adapter in self.prompt_adapter_requests ] model_cards.extend(prompt_adapter_cards) - model_cards.extend(lora_cards) return ModelList(data=model_cards) def create_error_response( From 4336df19bca190c24b7dd9de33e49f134bf08ee4 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 22:05:01 -0400 Subject: [PATCH 77/80] formatting --- vllm/entrypoints/openai/serving_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9ba5450891e8b..017e4e05024b1 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -101,14 +101,13 @@ async def show_available_models(self) -> ModelList: 
permission=[ModelPermission()]) for lora in self.lora_requests ] - model_cards.extend(lora_cards) - prompt_adapter_cards = [ ModelCard(id=prompt_adapter.prompt_adapter_name, root=self.served_model_names[0], permission=[ModelPermission()]) for prompt_adapter in self.prompt_adapter_requests ] + model_cards.extend(lora_cards) model_cards.extend(prompt_adapter_cards) return ModelList(data=model_cards) From 77183d777fb0b9ab7c9a23b5c85aa3bd6453b71f Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 8 Jul 2024 23:41:09 -0400 Subject: [PATCH 78/80] quick openapi fix --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 8 +++++-- vllm/entrypoints/openai/serving_engine.py | 24 ++++++------------- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 415bdbbd7c455..010d6f2ebb909 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -258,7 +258,7 @@ async def create_chat_completion( prompt=prompt, add_special_tokens=request.add_special_tokens) sampling_params = request.to_sampling_params() - lora_request = self._maybe_get_lora(request) + _, lora_request = self._maybe_get_adapter(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 592041d474ff5..4cbea5bbdc4ae 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -104,8 +104,12 @@ async def create_completion(self, request: CompletionRequest, generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() - lora_request = self._maybe_get_lora(request) - prompt_adapter_request = self._maybe_get_prompt_adapter(request) + adapter_type, adapter_request = self._maybe_get_adapter(request) + lora_request, prompt_adapter_request = None, None + if adapter_type == 'LoRA': + lora_request, prompt_adapter_request = adapter_request, None + elif adapter_type == 'PromptAdapter': + lora_request, prompt_adapter_request = None, adapter_request decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 017e4e05024b1..f32475ddc6c61 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -152,31 +152,21 @@ async def _check_model( err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora( + def _maybe_get_adapter( self, request: Union[CompletionRequest, ChatCompletionRequest, EmbeddingRequest] - ) -> Optional[LoRARequest]: + ) -> Tuple[Optional[str], Optional[Union[LoRARequest, + PromptAdapterRequest]]]: if request.model in self.served_model_names: - return None + return None, None for lora in self.lora_requests: if request.model == lora.lora_name: - return lora - return None - # if _check_model has been called earlier, this will be unreachable - #raise ValueError(f"The model `{request.model}` does not exist.") - - def _maybe_get_prompt_adapter( - self, request: Union[CompletionRequest, ChatCompletionRequest, - EmbeddingRequest] - ) -> Optional[PromptAdapterRequest]: - if 
request.model in self.served_model_names: - return None + return 'LoRA', lora for prompt_adapter in self.prompt_adapter_requests: if request.model == prompt_adapter.prompt_adapter_name: - return prompt_adapter - return None + return 'PromptAdapter', prompt_adapter # if _check_model has been called earlier, this will be unreachable - #raise ValueError(f"The model `{request.model}` does not exist.") + raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( self, From dd887f8bd27491c386d9c3b35838163afffaa1d5 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 9 Jul 2024 06:32:26 -0400 Subject: [PATCH 79/80] formatting --- vllm/entrypoints/openai/serving_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f32475ddc6c61..9416ec90cd72d 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -61,9 +61,9 @@ def __init__( self.served_model_names = served_model_names - if lora_modules is None: - self.lora_requests = [] - else: + + self.lora_requests = [] + if lora_modules is not None: self.lora_requests = [ LoRARequest( lora_name=lora.name, From 67a9f17d14b683529fbe5beee23ab081822c7d96 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 9 Jul 2024 06:49:46 -0400 Subject: [PATCH 80/80] formatting --- vllm/entrypoints/openai/serving_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9416ec90cd72d..58e6571d310e6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -61,7 +61,6 @@ def __init__( self.served_model_names = served_model_names - self.lora_requests = [] if lora_modules is not None: self.lora_requests = [