forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CORE] Adding support for insertion of soft-tuned prompts (vllm-proje…
…ct#4645) Co-authored-by: Swapnil Parekh <swapnilp@ibm.com> Co-authored-by: Joe G <joseph.granados@h2o.ai> Co-authored-by: Antoni Baum <antoni.baum@protonmail.com> (cherry picked from commit 4d6ada9)
- Loading branch information
1 parent
b79fdd6
commit bb626df
Showing
48 changed files
with
1,951 additions
and
518 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pytest | ||
|
||
import vllm | ||
from vllm.prompt_adapter.request import PromptAdapterRequest | ||
|
||
# Base model under test and the HF Hub id of a prompt-tuning adapter
# trained for tweet-complaint classification on that model.
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
|
||
|
||
def do_sample(llm, pa_name: str, pa_id: int):
    """Generate short label completions for two fixed tweet prompts.

    A prompt adapter (loaded from ``PA_PATH``) is attached when ``pa_id``
    is truthy; otherwise the base model is queried directly.  Returns the
    stripped generated texts in prompt order.
    """
    prompts = [
        "Tweet text : @nationalgridus I have no water and the bill is \
        current and paid. Can you do something about this? Label : ",
        "Tweet text : @nationalgridus Looks good thanks! Label : "
    ]
    sampling_params = vllm.SamplingParams(temperature=0.0,
                                          max_tokens=3,
                                          stop_token_ids=[3])

    adapter_request = None
    if pa_id:
        adapter_request = PromptAdapterRequest(pa_name, pa_id, PA_PATH, 8)
    outputs = llm.generate(prompts,
                           sampling_params,
                           prompt_adapter_request=adapter_request)

    # Collect (and echo) each generated completion.
    texts = []
    for output in outputs:
        text = output.outputs[0].text.strip()
        print(f"Prompt: {output.prompt!r}, Generated text: {text!r}")
        texts.append(text)
    return texts
|
||
|
||
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
    """End-to-end: the tweet prompt adapter steers bloomz toward the labels."""
    llm = vllm.LLM(MODEL_PATH,
                   enforce_eager=enforce_eager,
                   enable_prompt_adapter=True,
                   max_prompt_adapter_token=8)

    assert do_sample(llm, "twitter_pa",
                     pa_id=1) == ['complaint', 'no complaint']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from vllm import EngineArgs, LLMEngine, SamplingParams | ||
from vllm.prompt_adapter.request import PromptAdapterRequest | ||
|
||
# Base model plus two distinct prompt-tuning adapters, used to verify that
# several adapters can be multiplexed through one engine.
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
|
||
|
||
def do_sample(engine):
    """Drive ``engine`` over four fixed tweet prompts and collect outputs.

    Two requests use adapters from ``pa_path2``, one runs without any
    adapter, and one uses the adapter from ``pa_path``.  Returns the set of
    finished output texts.
    """
    pending = [
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3), None),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("complain", 3, pa_path, 8)),
    ]

    finished_texts = set()
    next_request_id = 0
    # Feed one prompt per engine step, then keep stepping until every
    # in-flight request has finished.
    while pending or engine.has_unfinished_requests():
        if pending:
            prompt, params, adapter_req = pending.pop(0)
            engine.add_request(str(next_request_id),
                               prompt,
                               params,
                               prompt_adapter_request=adapter_req)
            next_request_id += 1

        for request_output in engine.step():
            if request_output.finished:
                finished_texts.add(request_output.outputs[0].text)
    return finished_texts
|
||
|
||
def test_multi_prompt_adapters():
    """Three different prompt adapters can be served by a single engine."""
    engine = LLMEngine.from_engine_args(
        EngineArgs(model=MODEL_PATH,
                   max_prompt_adapters=3,
                   enable_prompt_adapter=True,
                   max_prompt_adapter_token=8))
    assert do_sample(engine) == {
        ' quot;I', 'hate speech', 'no complaint', 'not hate speech'
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from huggingface_hub import snapshot_download | ||
|
||
from vllm import EngineArgs, LLMEngine, SamplingParams | ||
from vllm.lora.request import LoRARequest | ||
from vllm.prompt_adapter.request import PromptAdapterRequest | ||
|
||
# Llama-2 base model with one prompt adapter and one LoRA adapter; both
# adapters are fetched from the Hub at import time.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||
|
||
def do_sample(engine):
    """Run two SQL-generation requests and return the finished texts.

    The first request stacks a prompt adapter on top of a LoRA adapter;
    the second uses only the LoRA adapter.  Both share the same prompt.
    """
    prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501

    # first prompt with a prompt adapter and second without adapter
    pending = [
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         PromptAdapterRequest("hate_speech", 1, pa_path, 8),
         LoRARequest("sql_test", 1, lora_path)),
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         None,
         LoRARequest("sql_test", 1, lora_path)),
    ]

    finished_texts = set()
    next_request_id = 0
    # Submit one request per step, then drain until nothing is in flight.
    while pending or engine.has_unfinished_requests():
        if pending:
            prompt, params, pa_req, lora_req = pending.pop(0)
            engine.add_request(str(next_request_id),
                               prompt,
                               params,
                               prompt_adapter_request=pa_req,
                               lora_request=lora_req)
            next_request_id += 1

        for request_output in engine.step():
            if request_output.finished:
                finished_texts.add(request_output.outputs[0].text)
    return finished_texts
|
||
|
||
def test_lora_prompt_adapter():
    """A prompt adapter and a LoRA adapter can be applied to one request."""
    engine = LLMEngine.from_engine_args(
        EngineArgs(model=MODEL_PATH,
                   enable_prompt_adapter=True,
                   enable_lora=True,
                   max_num_seqs=60,
                   max_prompt_adapter_token=8))
    result = do_sample(engine)

    expected_output = {
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
    }
    assert result == expected_output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from dataclasses import dataclass | ||
from typing import Tuple | ||
|
||
|
||
@dataclass
class AdapterMapping:
    """Mapping from batch positions to adapter indices."""
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        # Coerce whatever sequences the caller supplied into immutable
        # tuples so the mapping cannot be mutated after construction.
        for name in ("index_mapping", "prompt_mapping"):
            setattr(self, name, tuple(getattr(self, name)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar | ||
|
||
from torch import nn | ||
|
||
from vllm.logger import init_logger | ||
from vllm.utils import LRUCache | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class AdapterModel(ABC):
    """Base class for a single loadable adapter."""

    def __init__(self, model_id=None):
        # Identifier for this adapter instance (None until one is assigned).
        self.id = model_id

    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        """Construct an adapter by loading weights/embeddings from a local
        checkpoint directory.

        NOTE(review): declared with a ``cls`` first parameter but without
        ``@classmethod`` — subclasses appear expected to implement this as
        a classmethod; confirm.
        """
        raise NotImplementedError("Subclasses must implement this method.")
|
||
|
||
# Type of the values stored in AdapterLRUCache.
T = TypeVar('T')
|
||
|
||
class AdapterLRUCache(LRUCache[T]):
    """LRU cache of adapters, keyed by integer adapter id, that invokes a
    deactivation callback whenever an entry is removed (including eviction).
    """

    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
                                                              None]):
        super().__init__(capacity)
        # Called with the adapter id whenever its entry leaves the cache.
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: Hashable, value: T):
        # Deactivate the adapter before delegating to the base removal hook,
        # so teardown happens while the entry still exists.
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)
|
||
|
||
class AdapterModelManager(ABC):
    """Abstract manager that tracks adapters attached to a model.

    Concrete subclasses decide how adapters are stored, activated on the
    wrapped model, and mapped onto batched inputs.
    """

    def __init__(
        self,
        model: nn.Module,
    ):
        """Initialize the manager for ``model``.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        # Every adapter known to this manager, keyed by integer adapter id.
        self._registered_adapters: Dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: Dict[int, None] = {}
        self.adapter_type = 'Adapter'
        # Most recently applied mapping (see set_adapter_mapping).
        self._last_mapping = None

    def __len__(self) -> int:
        # Count of registered (not necessarily active) adapters.
        return len(self._registered_adapters)

    @property
    @abstractmethod
    def adapter_slots(self):
        """Number of adapter slots available (subclass-defined)."""

    @property
    @abstractmethod
    def capacity(self):
        """Total adapter capacity of this manager (subclass-defined)."""

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        """Activate the adapter with the given id."""

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate the adapter with the given id."""

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        """Register an adapter with this manager."""

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        """Apply a position-to-adapter mapping for subsequent execution."""

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        """Unregister the adapter with the given id."""

    @abstractmethod
    def remove_all_adapters(self):
        """Unregister every adapter."""

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Return the adapter with the given id, or None if unknown."""

    @abstractmethod
    def list_adapters(self) -> Dict[int, Any]:
        """Return the registered adapters keyed by id."""

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        """Pin the adapter with the given id (semantics subclass-defined)."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass
class AdapterRequest:
    """
    Base class for adapter requests.

    Subclasses expose a positive integer ``adapter_id``; equality and
    hashing are defined purely in terms of that id.
    """

    @property
    @abstractmethod
    def adapter_id(self):
        """Positive integer identifying the requested adapter."""
        ...

    def __post_init__(self):
        # Reject non-positive ids as soon as the request is constructed.
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        # Two requests are equal iff they are the same concrete type and
        # refer to the same adapter id.
        if not isinstance(value, self.__class__):
            return False
        return self.adapter_id == value.adapter_id

    def __hash__(self) -> int:
        return hash(self.adapter_id)
Oops, something went wrong.