[CORE] Adding support for insertion of soft-tuned prompts #4645

Merged
merged 91 commits into from
Jul 9, 2024
Changes from 79 commits
Commits
91 commits
04f262e
soft prompt support
Jun 3, 2024
96b4a1a
Run yapf and ruff
Jun 3, 2024
3131273
Multimodal fix
Jun 3, 2024
e9ff38b
correctness update
Jun 3, 2024
9f0a8ae
formatting
Jun 3, 2024
c2937d1
formatting
Jun 3, 2024
e43e89b
reverting to hasattr
Jun 3, 2024
a2b4fc3
adapter commons fix
Jun 3, 2024
3ebee19
minor fixes
Jun 3, 2024
629a684
formatting
Jun 3, 2024
a3ad6ac
reset_adapter
Jun 3, 2024
dcd7e88
bugfix
Jun 3, 2024
647a32d
reset_adapter fix
Jun 4, 2024
90d170c
peft dependencies
Jun 4, 2024
0fca895
fixing llava bug
Jun 4, 2024
d4e531c
typing fix
Jun 4, 2024
b7f8256
async engine update
Jun 4, 2024
449d988
batchwise processing
Jun 5, 2024
f28b66e
formatting
Jun 5, 2024
220deef
formatting yapf
Jun 5, 2024
01b9bb8
formatting again
Jun 5, 2024
2ea2796
enable_adapter parameter
Jun 5, 2024
96fe5ae
formatting
Jun 5, 2024
47725d9
adding test
Jun 5, 2024
638795a
adding test
Jun 5, 2024
f7d53b3
test case update
Jun 5, 2024
16f4037
formatting
Jun 5, 2024
f2f3cbc
resetting
Jun 13, 2024
0fc0c34
formatting
Jun 13, 2024
4eb47d6
formatting
Jun 13, 2024
e69842b
formatting
Jun 13, 2024
5c17480
Fix async engine
g-eoj Jun 13, 2024
e62cbb5
Initial implementation of openai entrypoint
g-eoj Jun 13, 2024
20fc56f
Merge branch 'main' into main
SwapnilDreams100 Jun 13, 2024
612d6c5
Fixes
g-eoj Jun 13, 2024
894b9ba
async changes
Jun 18, 2024
00efe02
Merge branch 'main' into main
SwapnilDreams100 Jun 18, 2024
155ad76
formatting
Jun 18, 2024
042c9f1
formatting
Jun 18, 2024
0e46a06
adding dtype flexibility + pa lora refactor
Jun 23, 2024
3d14475
formatting
Jun 23, 2024
86e72de
formatting
Jun 23, 2024
41934cc
xpu compatibility
Jun 23, 2024
fdfec59
xpu compatibility
Jun 23, 2024
6b1f0e7
xpu compatibility
Jun 23, 2024
01bb713
xpu compatibility
Jun 23, 2024
3e5e147
Merge branch 'main' into main
SwapnilDreams100 Jun 23, 2024
d7312e2
formatting
Jun 23, 2024
454d45b
formatting + updating tests
Jun 24, 2024
409dba1
test changes
Jun 24, 2024
ab95ad7
cpu-gpu sync changes + adapter abstract changes
Jun 26, 2024
2faec61
formatting
Jun 26, 2024
f1a607c
Merge branch 'main' into main
SwapnilDreams100 Jun 26, 2024
6955301
rebase
Jun 26, 2024
2814aee
peft fix
Jun 26, 2024
0e45660
minor fix
Jun 26, 2024
d58e355
formatting
Jun 26, 2024
d700324
forward update
Jun 30, 2024
a5610a7
formatting
Jun 30, 2024
6b1c5ef
Merge branch 'main' into main
SwapnilDreams100 Jul 1, 2024
8b6e827
formatting
Jul 1, 2024
b83b6f0
spec decode fix
Jul 1, 2024
4babf0f
Merge branch 'main' into main
SwapnilDreams100 Jul 2, 2024
791ffbd
formatting
Jul 2, 2024
7226246
Merge branch 'main' into main
SwapnilDreams100 Jul 2, 2024
215947d
async executor
Jul 2, 2024
9ae47e8
formatting
Jul 2, 2024
3a2b545
formatting
Jul 2, 2024
bbaea88
formatting
Jul 2, 2024
34dbc8f
Merge branch 'main' into openai-entrypoint
g-eoj Jul 3, 2024
9c2cc27
Merge branch 'main' into main
SwapnilDreams100 Jul 3, 2024
cdcea67
formatting
Jul 3, 2024
e771d43
max_prompt_adapter_token defaults + error messages
Jul 3, 2024
503adf4
updating tests
Jul 3, 2024
45c12ee
fix eager issue
Jul 5, 2024
9a73128
Merge branch 'main' into main
SwapnilDreams100 Jul 5, 2024
13d42c6
formatting
Jul 5, 2024
b2f3842
formatting
Jul 5, 2024
191f2c9
replacing numel w ndim for LoRA consistency
Jul 6, 2024
50514c3
Update tests/prompt_adapter/test_bloom.py
SwapnilDreams100 Jul 8, 2024
1217964
Update vllm/prompt_adapter/models.py
SwapnilDreams100 Jul 8, 2024
f9a5b4a
formatting
Jul 8, 2024
8545205
formatting
Jul 8, 2024
2d5c246
formatting
Jul 8, 2024
3da2777
docs update
Jul 8, 2024
9634b9d
Merge pull request #2 from g-eoj/openai-entrypoint
SwapnilDreams100 Jul 9, 2024
8279496
formatting
Jul 9, 2024
4336df1
formatting
Jul 9, 2024
77183d7
quick openapi fix
Jul 9, 2024
dd887f8
formatting
Jul 9, 2024
67a9f17
formatting
Jul 9, 2024
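For orientation before the file-level changes: a minimal usage sketch of the prompt-adapter support this PR adds, mirroring tests/prompt_adapter/test_bloom.py below. The model, adapter repository, and virtual-token count come from that test; treat the exact call signatures as a sketch of the new API rather than authoritative documentation.

import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest

# Model and soft-prompt adapter used in tests/prompt_adapter/test_bloom.py;
# the final argument (8) is the adapter's number of virtual tokens.
llm = vllm.LLM("bigscience/bloomz-560m",
               enable_prompt_adapter=True,
               max_prompt_adapter_token=8)

pa_request = PromptAdapterRequest(
    "twitter_pa", 1, "stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM", 8)

outputs = llm.generate(
    ["Tweet text : @nationalgridus Looks good thanks! Label : "],
    vllm.SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
    prompt_adapter_request=pa_request)
print(outputs[0].outputs[0].text.strip())  # 'no complaint' expected per the test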
1 change: 1 addition & 0 deletions format.sh
@@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy tests --config-file pyproject.toml


9 changes: 4 additions & 5 deletions tests/lora/test_long_context.py
@@ -92,11 +92,10 @@ def batched_generate(
for input in inputs:
prompt, sampling_param, lora_req = input
# Add requests to the engine and run the engine
llm._validate_and_add_requests(
prompt,
sampling_param,
lora_request=lora_req,
)
llm._validate_and_add_requests(prompt,
sampling_param,
lora_request=lora_req,
prompt_adapter_request=None)

outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
326 changes: 164 additions & 162 deletions tests/lora/test_lora_manager.py

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions tests/prompt_adapter/test_bloom.py
@@ -0,0 +1,45 @@
import pytest

import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'


def do_sample(llm, pa_name: str, pa_id: int):

prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])

outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)

# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)

expected_output = ['complaint', 'no complaint']

assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
53 changes: 53 additions & 0 deletions tests/prompt_adapter/test_multi_adapter_inference.py
@@ -0,0 +1,53 @@
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'


def do_sample(engine):

prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]

request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1

request_outputs = engine.step()

for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results


def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output
61 changes: 61 additions & 0 deletions tests/prompt_adapter/test_pa_lora.py
@@ -0,0 +1,61 @@
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


def do_sample(engine):

prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501

# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]

request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1

request_outputs = engine.step()

for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results


def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)

expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output
2 changes: 2 additions & 0 deletions tests/spec_decode/e2e/conftest.py
@@ -13,6 +13,7 @@
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
@@ -92,6 +93,7 @@ def generate(
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalDataDict] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> List[RequestOutput]:

if prompts is None:
1 change: 1 addition & 0 deletions tests/worker/test_model_runner.py
@@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True,
)
return model_runner
Empty file.
14 changes: 14 additions & 0 deletions vllm/adapter_commons/layers.py
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]

def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
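To make the two fields concrete, a small illustrative example (the batch shape and adapter ids are invented for illustration):

from vllm.adapter_commons.layers import AdapterMapping

# Hypothetical batch: two prompts of 3 and 2 tokens, routed to adapter ids 1 and 2.
mapping = AdapterMapping(
    index_mapping=(1, 1, 1, 2, 2),  # one adapter id per token in input_ids
    prompt_mapping=(1, 2),          # one adapter id per sampled sequence
)
assert isinstance(mapping.index_mapping, tuple)  # __post_init__ normalizes both to tuples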
104 changes: 104 additions & 0 deletions vllm/adapter_commons/models.py
@@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar

from torch import nn

from vllm.logger import init_logger
from vllm.utils import LRUCache

logger = init_logger(__name__)


class AdapterModel(ABC):

def __init__(self, model_id=None):
self.id = model_id

@abstractmethod
def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
# Common initialization code
# Load weights or embeddings from local checkpoint
raise NotImplementedError("Subclasses must implement this method.")


T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):

def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn

def _on_remove(self, key: Hashable, value: T):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)


class AdapterModelManager(ABC):

def __init__(
self,
model: nn.Module,
):
"""Create a AdapterModelManager and adapter for a given model.
Args:
model: the model to be adapted.
"""
self.model: nn.Module = model
self._registered_adapters: Dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None

def __len__(self) -> int:
return len(self._registered_adapters)

@property
@abstractmethod
def adapter_slots(self):
...

@property
@abstractmethod
def capacity(self):
...

@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...

@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...

@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...

@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...

@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...

@abstractmethod
def remove_all_adapters(self):
...

@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...

@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...

@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...
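The LRU cache above is what ties eviction to adapter deactivation. A minimal sketch of that behavior, assuming vllm.utils.LRUCache evicts the least-recently-used entry once put() exceeds capacity:

from vllm.adapter_commons.models import AdapterLRUCache

deactivated = []

# deactivate_fn stands in for a manager's deactivate_adapter(); here it only records ids.
cache = AdapterLRUCache(capacity=2, deactivate_fn=deactivated.append)
cache.put(1, "adapter-1")
cache.put(2, "adapter-2")
cache.put(3, "adapter-3")  # over capacity: the oldest entry (id 1) is evicted
assert deactivated == [1]  # _on_remove invoked deactivate_fn with the evicted key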
25 changes: 25 additions & 0 deletions vllm/adapter_commons/request.py
@@ -0,0 +1,25 @@
from abc import abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest:
"""
Base class for adapter requests.
"""

@property
@abstractmethod
def adapter_id(self):
...

def __post_init__(self):
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")

def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id

def __hash__(self) -> int:
return hash(self.adapter_id)
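A concrete request type only needs to expose adapter_id; the base class then supplies validation, equality, and hashing. A toy subclass for illustration (field names invented here; the PR's real subclass is PromptAdapterRequest in vllm/prompt_adapter/request.py):

from dataclasses import dataclass

from vllm.adapter_commons.request import AdapterRequest


@dataclass(eq=False)  # keep the base class's adapter_id-based __eq__/__hash__
class ToyAdapterRequest(AdapterRequest):
    name: str
    toy_adapter_id: int
    local_path: str

    @property
    def adapter_id(self) -> int:
        return self.toy_adapter_id


req = ToyAdapterRequest("demo", 1, "/tmp/toy-adapter")
assert req.adapter_id == 1
# ToyAdapterRequest("bad", 0, "/tmp/x") would raise ValueError in __post_init__.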