
[CORE] Adding support for insertion of soft-tuned prompts #4645

Merged
merged 91 commits on Jul 9, 2024
Changes from 1 commit
Commits (91)
04f262e
soft prompt support
Jun 3, 2024
96b4a1a
Run yapf and ruff
Jun 3, 2024
3131273
Multimodal fix
Jun 3, 2024
e9ff38b
correctness update
Jun 3, 2024
9f0a8ae
formatting
Jun 3, 2024
c2937d1
formatting
Jun 3, 2024
e43e89b
reverting to hasattr
Jun 3, 2024
a2b4fc3
adapter commons fix
Jun 3, 2024
3ebee19
minor fixes
Jun 3, 2024
629a684
formatting
Jun 3, 2024
a3ad6ac
reset_adapter
Jun 3, 2024
dcd7e88
bugfix
Jun 3, 2024
647a32d
reset_adapter fix
Jun 4, 2024
90d170c
peft dependencies
Jun 4, 2024
0fca895
fixing llava bug
Jun 4, 2024
d4e531c
typing fix
Jun 4, 2024
b7f8256
async engine update
Jun 4, 2024
449d988
batchwise processing
Jun 5, 2024
f28b66e
formatting
Jun 5, 2024
220deef
formatting yapf
Jun 5, 2024
01b9bb8
formatting again
Jun 5, 2024
2ea2796
enable_adapter paramter
Jun 5, 2024
96fe5ae
formatting
Jun 5, 2024
47725d9
adding test
Jun 5, 2024
638795a
adding test
Jun 5, 2024
f7d53b3
test case update
Jun 5, 2024
16f4037
formatting
Jun 5, 2024
f2f3cbc
resetting
Jun 13, 2024
0fc0c34
formatting
Jun 13, 2024
4eb47d6
formatting
Jun 13, 2024
e69842b
formatting
Jun 13, 2024
5c17480
Fix async engine
g-eoj Jun 13, 2024
e62cbb5
Initial implementation of openai entrypoint
g-eoj Jun 13, 2024
20fc56f
Merge branch 'main' into main
SwapnilDreams100 Jun 13, 2024
612d6c5
Fixes
g-eoj Jun 13, 2024
894b9ba
async changes
Jun 18, 2024
00efe02
Merge branch 'main' into main
SwapnilDreams100 Jun 18, 2024
155ad76
formattign
Jun 18, 2024
042c9f1
formatting
Jun 18, 2024
0e46a06
adding dtype flexibility + pa lora refactor
Jun 23, 2024
3d14475
formatting
Jun 23, 2024
86e72de
formatting
Jun 23, 2024
41934cc
xpu compatibility
Jun 23, 2024
fdfec59
xpu compatibility
Jun 23, 2024
6b1f0e7
xpu compatibility
Jun 23, 2024
01bb713
xpu compatibility
Jun 23, 2024
3e5e147
Merge branch 'main' into main
SwapnilDreams100 Jun 23, 2024
d7312e2
formatting
Jun 23, 2024
454d45b
formatting + updating tests
Jun 24, 2024
409dba1
test changes
Jun 24, 2024
ab95ad7
cpu-gpu sync changes + adapter abstract changes
Jun 26, 2024
2faec61
formatting
Jun 26, 2024
f1a607c
Merge branch 'main' into main
SwapnilDreams100 Jun 26, 2024
6955301
rebase
Jun 26, 2024
2814aee
peft fix
Jun 26, 2024
0e45660
minor fix
Jun 26, 2024
d58e355
formatting
Jun 26, 2024
d700324
forward update
Jun 30, 2024
a5610a7
formatting
Jun 30, 2024
6b1c5ef
Merge branch 'main' into main
SwapnilDreams100 Jul 1, 2024
8b6e827
formatting
Jul 1, 2024
b83b6f0
spec decode fix
Jul 1, 2024
4babf0f
Merge branch 'main' into main
SwapnilDreams100 Jul 2, 2024
791ffbd
formatting
Jul 2, 2024
7226246
Merge branch 'main' into main
SwapnilDreams100 Jul 2, 2024
215947d
async executor
Jul 2, 2024
9ae47e8
formatting
Jul 2, 2024
3a2b545
formatting
Jul 2, 2024
bbaea88
formatting
Jul 2, 2024
34dbc8f
Merge branch 'main' into openai-entrypoint
g-eoj Jul 3, 2024
9c2cc27
Merge branch 'main' into main
SwapnilDreams100 Jul 3, 2024
cdcea67
formatting
Jul 3, 2024
e771d43
max_prompt_adapter_token defaults + error messages
Jul 3, 2024
503adf4
updating tests
Jul 3, 2024
45c12ee
fix eager issue
Jul 5, 2024
9a73128
Merge branch 'main' into main
SwapnilDreams100 Jul 5, 2024
13d42c6
formatting
Jul 5, 2024
b2f3842
formatting
Jul 5, 2024
191f2c9
replacing numel w ndim for LoRA consistency
Jul 6, 2024
50514c3
Update tests/prompt_adapter/test_bloom.py
SwapnilDreams100 Jul 8, 2024
1217964
Update vllm/prompt_adapter/models.py
SwapnilDreams100 Jul 8, 2024
f9a5b4a
formatting
Jul 8, 2024
8545205
formatting
Jul 8, 2024
2d5c246
formatting
Jul 8, 2024
3da2777
docs update
Jul 8, 2024
9634b9d
Merge pull request #2 from g-eoj/openai-entrypoint
SwapnilDreams100 Jul 9, 2024
8279496
formatting
Jul 9, 2024
4336df1
formatting
Jul 9, 2024
77183d7
quick openapi fix
Jul 9, 2024
dd887f8
formatting
Jul 9, 2024
67a9f17
formatting
Jul 9, 2024
batchwise processing
Swapnil Parekh committed Jun 13, 2024
commit 449d988ad220b1abfd52334716e445d7106522a5
13 changes: 13 additions & 0 deletions vllm/adapter_commons/layers.py
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Tuple

@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]

def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
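
For context, this new vllm/adapter_commons/layers.py file factors the token-to-adapter mapping shared by LoRA and prompt adapters into one base dataclass. Below is a minimal standalone sketch of how it behaves; the dataclass mirrors the hunk above, while the subclass name and example values are only illustrative:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


# Subclasses such as LoRAMapping and PromptAdapterMapping add no fields;
# they just give each adapter type its own mapping class.
@dataclass
class PromptAdapterMapping(AdapterMapping):
    pass


# Lists are accepted and coerced to tuples by __post_init__.
mapping = PromptAdapterMapping(index_mapping=[1, 1, 0, 0], prompt_mapping=[1, 0])
assert mapping.index_mapping == (1, 1, 0, 0)
assert mapping.prompt_mapping == (1, 0)
```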
12 changes: 3 additions & 9 deletions vllm/lora/layers.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -134,15 +135,8 @@ def _apply_lora_packed_nslice(


@dataclass
class LoRAMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]

def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
class LoRAMapping(AdapterMapping):
pass


class BaseLayerWithLoRA(nn.Module):
2 changes: 0 additions & 2 deletions vllm/model_executor/models/baichuan.py
@@ -43,7 +43,6 @@
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.prompt_adapter.layers import apply_prompt_adapter
from vllm.sequence import SamplerOutput


@@ -279,7 +278,6 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = apply_prompt_adapter(self, hidden_states, positions)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
2 changes: 0 additions & 2 deletions vllm/model_executor/models/bloom.py
@@ -39,7 +39,6 @@
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.prompt_adapter.layers import apply_prompt_adapter
from vllm.sequence import SamplerOutput


@@ -252,7 +251,6 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = apply_prompt_adapter(self, hidden_states, position_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
layer = self.h[i]
2 changes: 0 additions & 2 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -39,7 +39,6 @@
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.prompt_adapter.layers import apply_prompt_adapter
from vllm.sequence import SamplerOutput


@@ -220,7 +219,6 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
inputs_embeds = apply_prompt_adapter(self, inputs_embeds, position_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

2 changes: 0 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -46,7 +46,6 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.prompt_adapter.layers import apply_prompt_adapter
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once

@@ -283,7 +282,6 @@ def forward(
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = apply_prompt_adapter(self, hidden_states, positions)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
2 changes: 0 additions & 2 deletions vllm/model_executor/models/mixtral.py
@@ -51,7 +51,6 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.prompt_adapter.layers import apply_prompt_adapter
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once

@@ -463,7 +462,6 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = apply_prompt_adapter(self, hidden_states, positions)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
69 changes: 50 additions & 19 deletions vllm/prompt_adapter/layers.py
@@ -1,25 +1,56 @@
from dataclasses import dataclass
from typing import Tuple
from typing import Dict, List, Optional

import numpy
import torch
from torch import nn

from vllm.adapter_commons.layers import AdapterMapping
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)


@dataclass
class PromptAdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]

def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)


def apply_prompt_adapter(instance, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
if hasattr(instance, 'prefix_encoder'):
soft_prompt = instance.prefix_encoder.prompt_embedding
indices = (positions < soft_prompt.shape[0])
hidden_states[indices] = soft_prompt[positions[indices]]
return hidden_states
class PromptAdapterMapping(AdapterMapping):
pass


class VocabParallelEmbeddingWithPromptAdapter(nn.Module):

def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embedding_tensors: Dict[int, torch.Tensor] = {}
self.indices: torch.Tensor

def reset_prompt_adapter(self, index: int):
self.embedding_tensors[index] = 0

def set_prompt_adapter(
self,
index: int,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_prompt_adapter(index)
if embeddings_tensor is not None:
self.embedding_tensors[index] = embeddings_tensor

def set_mapping(
self,
base_indices: List[int],
):
self.indices = base_indices

def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = self.base_layer(x)
unique_indices = numpy.unique(self.indices)
for idx in unique_indices:
if idx != 0:
pa_idx = self.embedding_tensors[idx].prompt_embedding
mask = (self.indices == idx)
try:
n_adapters = sum(mask) // pa_idx.shape[0]
hidden_states[mask] = pa_idx.repeat(n_adapters, 1)
except Exception:
pass
return hidden_states
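
The rewritten vllm/prompt_adapter/layers.py above moves soft-prompt insertion out of the per-model forward hooks and into a wrapper around the embedding layer, applied batchwise from per-token adapter indices. A rough self-contained sketch of that idea follows, using a plain torch.nn.Embedding as a stand-in for VocabParallelEmbedding; the class name, shapes, and index values here are hypothetical, not the PR's API:

```python
from typing import Dict, Optional

import torch
from torch import nn


class EmbeddingWithSoftPrompt(nn.Module):
    """Toy stand-in for VocabParallelEmbeddingWithPromptAdapter."""

    def __init__(self, base_layer: nn.Embedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # adapter slot -> soft-prompt tensor of shape [num_virtual_tokens, hidden_size]
        self.embedding_tensors: Dict[int, torch.Tensor] = {}
        self.indices: Optional[torch.Tensor] = None  # per-token adapter index

    def set_prompt_adapter(self, index: int, embedding: torch.Tensor) -> None:
        self.embedding_tensors[index] = embedding

    def set_mapping(self, base_indices: torch.Tensor) -> None:
        self.indices = base_indices

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = self.base_layer(x)
        for idx in torch.unique(self.indices).tolist():
            if idx == 0:  # 0 marks tokens with no active prompt adapter
                continue
            soft_prompt = self.embedding_tensors[idx]
            mask = self.indices == idx
            # Assumes the masked positions are exactly the virtual-token slots
            # of the request(s) using this adapter, stacked in order.
            n_repeats = int(mask.sum()) // soft_prompt.shape[0]
            hidden_states[mask] = soft_prompt.repeat(n_repeats, 1)
        return hidden_states


# Toy batch of 6 tokens: the first two positions belong to adapter slot 1.
layer = EmbeddingWithSoftPrompt(nn.Embedding(100, 8))
layer.set_prompt_adapter(1, torch.randn(2, 8))
layer.set_mapping(torch.tensor([1, 1, 0, 0, 0, 0]))
out = layer(torch.randint(0, 100, (6,)))
print(out.shape)  # torch.Size([6, 8])
```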
136 changes: 34 additions & 102 deletions vllm/prompt_adapter/models.py
@@ -1,15 +1,15 @@
import logging
import math
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Callable, Dict, List, Optional, Type

import torch
from peft.utils import load_peft_weights
from torch import nn

from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.layers import (
PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter)

logger = logging.getLogger(__name__)

@@ -22,69 +22,6 @@ def get_prompt_adapter_id():
return _GLOBAL_PROMPT_ADAPTER_ID


def convert_mapping(
mapping: PromptAdapterMapping,
prompt_adapter_index_to_id: List[Optional[int]], max_prompt_adapters: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""Converts PromptAdapterMapping to index tensors.

Args:
mapping: PromptAdapterMapping mapping rows in a batch to ids.
prompt_adapter_index_to_id: List mapping PromptAdapter ids to indices.
max_prompt_adapters: Maximum number of PromptAdapters.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
PromptAdapter indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
PromptAdapter indices for sampler. For generation, this will be
same as base_indicies. For prefill, this will map requests
to PromptAdapter indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to PromptAdapter indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_promt_adapters.
indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded).
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
prompt_adapter_indices = index_mapping_indices.copy()
prompt_mapping: List[int] = [
prompt_adapter_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
prompt_adapter_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
prompt_adapter_idx = (prompt_adapter_index_to_id.index(
index_mapping_indices[i]) if index_mapping_indices[i] > 0 else -1)
prompt_adapter_indices[i] = prompt_adapter_idx

indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices, prompt_adapter_indices
]
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded ==
-1] = max_prompt_adapters - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1]
]
return (base_indices, sampler_indices, sampler_indices_padded, indices_len)


class PromptAdapterModel(AdapterModel):

def __init__(self,
@@ -133,16 +70,9 @@ def __init__(
self.model.prompt_adapter_manager = self
self.adapter_type = 'PromptAdapter'

self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.indices_len: List[Optional[int]] = [None] * 3
self.base_indices = [0]
self.modules: Dict[str, nn.Module] = {}
self._create_prompt_adapter_modules()
self._last_mapping: Optional[PromptAdapterMapping] = None

@property
@@ -157,15 +87,6 @@ def adapter_slots(self) -> int:
def capacity(self) -> int:
return self.prompt_adapter_config.max_cpu_prompt_adapters

def reset_adapter(self):
try:
self.remove_all_prompt_adapters()
for module_name, module in self.model.named_modules():
if 'Model' in (module.__class__.__name__):
del module.prefix_encoder
except Exception:
pass

def activate_prompt_adapter(
self,
prompt_adapter_id: int,
@@ -187,10 +108,8 @@ def activate_prompt_adapter(
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
prompt_adapter_model.id, index)
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
for module_name, module in self.model.named_modules():
if 'Model' in (module.__class__.__name__):
module.prefix_encoder = prompt_adapter_model
break
for _, v in self.modules.items():
v.set_prompt_adapter(prompt_adapter_id, prompt_adapter_model)
return True

@property
Expand All @@ -201,9 +120,8 @@ def _deactivate_prompt_adapter(self, prompt_adapter_id: int):
try:
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
self.prompt_adapter_index_to_id[index] = None
for module_name, module in self.model.named_modules():
if 'Model' in (module.__class__.__name__):
del module.prefix_encoder
for _, v in self.modules.items():
v.reset_prompt_adapter(prompt_adapter_id)
except ValueError:
pass

@@ -232,16 +150,30 @@ def remove_prompt_adapter(self):

def _set_prompt_adapter_mapping(self,
mapping: PromptAdapterMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
indices_len) = convert_mapping(mapping,
self.prompt_adapter_index_to_id,
self.prompt_adapter_slots + 1)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
# Maintain the reference
self.indices_len[:] = indices_len
for k, v in self.modules.items():
v.set_mapping(mapping.index_mapping)

def _create_prompt_adapter_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if "VocabParallel" in module.__class__.__name__:
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
replaced_module = self.replace_submodule(
self.model, module_name, new_module)
self.register_module(module.__class__.__name__,
replaced_module)
replaced_module.set_mapping(self.base_indices)

def replace_submodule(self, model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module

def register_module(self, module_name: str, module: nn.Module):
self.modules[module_name] = module

@property
def set_prompt_adapter_mapping(self):
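
The replace_submodule helper above follows the common PyTorch pattern of re-registering a wrapped module under its original dotted name, so the rest of the model sees the wrapper transparently. A small self-contained sketch of that pattern; the toy model, wrapper class, and names are hypothetical and only illustrate the mechanism:

```python
import torch
from torch import nn


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a (possibly nested) submodule of `model` in place."""
    # get_submodule("") returns the model itself, so top-level names work too.
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module


class Wrapped(nn.Module):
    """Toy wrapper standing in for VocabParallelEmbeddingWithPromptAdapter."""

    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        self.base_layer = base_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x)


class ToyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)


model = ToyModel()
# Walk the named modules, wrap the embedding layer, and re-register it under
# the same dotted name.
for name, module in list(model.named_modules(remove_duplicate=False)):
    if isinstance(module, nn.Embedding):
        replace_submodule(model, name, Wrapped(module))

print(model.embed_tokens)  # Wrapped( (base_layer): Embedding(10, 4) )
```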
3 changes: 0 additions & 3 deletions vllm/prompt_adapter/worker_manager.py
@@ -44,9 +44,6 @@ def __init__(
def is_enabled(self) -> bool:
return True

def reset_adapter(self):
self._prompt_adapter_manager.reset_adapter()

def create_prompt_adapter_manager(
self,
model: torch.nn.Module,