[V1] Implement vLLM V1 [1/N] #9289

Merged 101 commits on Oct 22, 2024

Commits (101)
0d8a651
Add vllm_v1
WoosukKwon Oct 11, 2024
9a5c899
Max num seqs
WoosukKwon Oct 11, 2024
b2b90a9
Fix chunked prefill
WoosukKwon Oct 11, 2024
0ee05d2
Fix flash-attn
WoosukKwon Oct 11, 2024
647cb1b
yapf
WoosukKwon Oct 11, 2024
2f04e52
Fix memory
WoosukKwon Oct 13, 2024
9a159be
Fix
WoosukKwon Oct 15, 2024
00d3975
Remove time
WoosukKwon Oct 15, 2024
dff359d
Minor
WoosukKwon Oct 15, 2024
0cb2454
Merge branch 'main' into re-arch-v1
WoosukKwon Oct 15, 2024
90390bc
Minor
WoosukKwon Oct 15, 2024
fa82f0d
Fix
WoosukKwon Oct 15, 2024
5cf508c
Revert
WoosukKwon Oct 15, 2024
8c476d5
Remove commit_id
WoosukKwon Oct 15, 2024
10e474a
Fix slot_mapping
WoosukKwon Oct 15, 2024
e35a3d2
Remove comment
WoosukKwon Oct 15, 2024
4ce3470
Fix
WoosukKwon Oct 15, 2024
815f137
Remove logits processor
WoosukKwon Oct 16, 2024
e7605a7
Fix dummy run
WoosukKwon Oct 16, 2024
ae5089b
comment
WoosukKwon Oct 16, 2024
05934ea
Fix
WoosukKwon Oct 16, 2024
fa5ad10
Remove redundancy
WoosukKwon Oct 16, 2024
ea44286
Minor
WoosukKwon Oct 16, 2024
789aeb8
Fix
WoosukKwon Oct 16, 2024
58053f0
Minor
WoosukKwon Oct 18, 2024
e56e3e5
Merge branch 'main' into re-arch-v1
WoosukKwon Oct 18, 2024
7b3219f
Clean up
WoosukKwon Oct 18, 2024
d0090d2
Add inits
WoosukKwon Oct 18, 2024
deacb3b
yapf
WoosukKwon Oct 18, 2024
e51bda9
vllm_v1 -> vllm.v1
WoosukKwon Oct 18, 2024
9abf055
yapf
WoosukKwon Oct 18, 2024
6dd0155
Fix
WoosukKwon Oct 18, 2024
c83966b
Add VLLM_USE_V1
WoosukKwon Oct 18, 2024
bffa71c
fix
WoosukKwon Oct 18, 2024
ba1dc5e
Minor
WoosukKwon Oct 18, 2024
8a9a114
Fix
WoosukKwon Oct 18, 2024
405f895
isort
WoosukKwon Oct 18, 2024
c7a70e9
Move detokenizer_utils
WoosukKwon Oct 18, 2024
2bec533
Minor
WoosukKwon Oct 18, 2024
f4f573b
yapf
WoosukKwon Oct 18, 2024
2a29e1d
Fix
WoosukKwon Oct 18, 2024
dc2106f
Minor
WoosukKwon Oct 18, 2024
68bd6f7
Rename ports
WoosukKwon Oct 18, 2024
f03d574
Comment
WoosukKwon Oct 18, 2024
44b152b
comment
WoosukKwon Oct 18, 2024
1b186a8
Comment
WoosukKwon Oct 18, 2024
4e07a47
Minor
WoosukKwon Oct 18, 2024
c6ab902
Add comments
WoosukKwon Oct 18, 2024
fd59c5e
Remove unused methods
WoosukKwon Oct 18, 2024
4afd2d2
Add check_health
WoosukKwon Oct 18, 2024
6cea5e7
Fix switching between V0 and V1 engine
WoosukKwon Oct 18, 2024
9978be3
Make async detokenizer work
WoosukKwon Oct 18, 2024
0e93601
yapf
WoosukKwon Oct 18, 2024
b96dd05
Do not send prompt tokens redundantly
WoosukKwon Oct 19, 2024
248f890
Remove async gpu executor
WoosukKwon Oct 19, 2024
6225c8d
compatibility
WoosukKwon Oct 19, 2024
f7752d8
Optimize random_sample
WoosukKwon Oct 20, 2024
c460da9
Remove
WoosukKwon Oct 20, 2024
b4a674b
Use dict
WoosukKwon Oct 20, 2024
0c5f5a9
yapf
WoosukKwon Oct 20, 2024
40b4c78
Fix
WoosukKwon Oct 20, 2024
b2aaea2
Minor
WoosukKwon Oct 20, 2024
d5ec4cb
Fix detokenizer
WoosukKwon Oct 20, 2024
fbbb771
Minor
WoosukKwon Oct 20, 2024
ad3b0d9
Detokenizer & DetokenizerProc
WoosukKwon Oct 20, 2024
91ae792
yapf
WoosukKwon Oct 20, 2024
9598d43
Minor
WoosukKwon Oct 20, 2024
40c5114
Merge branch 'main' into re-arch-v1
WoosukKwon Oct 20, 2024
0ef47d5
Fix
WoosukKwon Oct 20, 2024
96e5781
Fix
WoosukKwon Oct 20, 2024
aefa95f
Optimize
WoosukKwon Oct 20, 2024
952fab8
Comment
WoosukKwon Oct 20, 2024
eb2008a
Fix
WoosukKwon Oct 20, 2024
ec8e871
Add comment on scheduler
WoosukKwon Oct 20, 2024
f811fe0
Optimize object creation
WoosukKwon Oct 20, 2024
f03416b
Optimize finish_requests
WoosukKwon Oct 20, 2024
cd57404
Minor:
WoosukKwon Oct 20, 2024
9f637d6
Minor
WoosukKwon Oct 20, 2024
8ac308f
Support API server
WoosukKwon Oct 20, 2024
d35fb71
Fix
WoosukKwon Oct 20, 2024
3af10d7
Minor
WoosukKwon Oct 20, 2024
da2958f
Support stop ids
WoosukKwon Oct 21, 2024
f89edac
Minor
WoosukKwon Oct 21, 2024
864dd27
RequestMetrics
WoosukKwon Oct 21, 2024
f8f7d23
Fix
WoosukKwon Oct 21, 2024
e5fb326
mypy
WoosukKwon Oct 21, 2024
380568c
mypy
WoosukKwon Oct 21, 2024
ec43110
mypy
WoosukKwon Oct 21, 2024
cd99b21
Fix
WoosukKwon Oct 21, 2024
76bb54f
TODO on top-p top-k
WoosukKwon Oct 21, 2024
261a1ef
Refactor
WoosukKwon Oct 21, 2024
5a2ddbf
Minor
WoosukKwon Oct 21, 2024
44412b5
Remove
WoosukKwon Oct 21, 2024
f8c8b8e
typo
WoosukKwon Oct 21, 2024
a0fa8eb
Merge branch 'main' into re-arch-v1
WoosukKwon Oct 21, 2024
3ba8865
Preallocate instead of watermark
WoosukKwon Oct 21, 2024
8c4b84c
RequestOutput
WoosukKwon Oct 21, 2024
0d21798
Minor
WoosukKwon Oct 21, 2024
c13b503
num_new_tokens
WoosukKwon Oct 21, 2024
804f0cd
Minor
WoosukKwon Oct 21, 2024
e441f0a
Add __init__
WoosukKwon Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions vllm/attention/selector.py
@@ -17,6 +17,7 @@

class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
@@ -110,6 +111,10 @@ def get_attn_backend(
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
if backend == _Backend.FLASH_ATTN_VLLM_V1:
from vllm.v1.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend as FlashAttentionBackendV1)
return FlashAttentionBackendV1
if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
@@ -215,6 +220,9 @@ def which_attn_to_use(
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

if envs.VLLM_USE_V1:
return _Backend.FLASH_ATTN_VLLM_V1

# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if not current_platform.has_device_capability(80):
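For illustration (not part of the diff): the two hunks above add a V1-only enum value and make which_attn_to_use short-circuit to it whenever VLLM_USE_V1 is set, after which get_attn_backend imports the flash-attn backend that lives under vllm.v1. The sketch below mirrors that dispatch pattern with stand-in logic only; the real selector also honors the explicitly requested backend, checks platform, dtype, and block size, and reads the flag through vllm.envs rather than a raw string compare.

import enum
import os


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()  # value added by this PR
    XFORMERS = enum.auto()


def which_attn_to_use_sketch() -> _Backend:
    # Stand-in for the real which_attn_to_use: when the V1 flag is set, the
    # remaining backend-selection logic is skipped entirely.
    if os.getenv("VLLM_USE_V1", "0") == "1":  # simplified parsing
        return _Backend.FLASH_ATTN_VLLM_V1
    return _Backend.FLASH_ATTN


os.environ["VLLM_USE_V1"] = "1"
print(which_attn_to_use_sketch())  # _Backend.FLASH_ATTN_VLLM_V1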
27 changes: 17 additions & 10 deletions vllm/engine/multiprocessing/engine.py
@@ -8,7 +8,7 @@
import cloudpickle
import zmq

from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
@@ -21,12 +21,17 @@
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext

if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]

@@ -136,14 +141,16 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,

executor_class = LLMEngine._get_executor_cls(engine_config)

return cls(
ipc_path=ipc_path,
use_async_sockets=engine_config.model_config.use_async_output_proc,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
use_async_sockets = (engine_config.model_config.use_async_output_proc
and not VLLM_USE_V1)

return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)

def start(self):
try:
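For illustration (not part of the diff): two things change in this file. The LLMEngine symbol is now chosen at import time based on VLLM_USE_V1, and the async output-processing sockets are forced off whenever V1 is active. The sketch below shows the flag-switched import pattern with the actual import left as a comment so it runs without vLLM installed; because the flag is read at import time, VLLM_USE_V1 generally has to be set before vllm is imported.

import os

# Import-time switch between the V0 and V1 engines (the real module does
# `from vllm.v1.engine.llm_engine import LLMEngine` or
# `from vllm.engine.llm_engine import LLMEngine`).
VLLM_USE_V1 = bool(int(os.getenv("VLLM_USE_V1", "0")))

if VLLM_USE_V1:
    engine_module = "vllm.v1.engine.llm_engine"
else:
    engine_module = "vllm.engine.llm_engine"
# LLMEngine = getattr(importlib.import_module(engine_module), "LLMEngine")

# Mirrors the use_async_sockets change above: with V1 active, the async
# socket path is disabled no matter what the model config says.
use_async_output_proc = True  # stand-in for model_config.use_async_output_proc
use_async_sockets = use_async_output_proc and not VLLM_USE_V1

print(engine_module, use_async_sockets)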
7 changes: 6 additions & 1 deletion vllm/entrypoints/llm.py
@@ -6,10 +6,10 @@

from tqdm import tqdm

from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
@@ -31,6 +31,11 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of

if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore

logger = init_logger(__name__)


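For illustration (not part of the diff): because the same import switch is applied to the LLM entrypoint, opting into the new engine needs no API changes on the user side. A minimal usage sketch follows, with an arbitrary example model and sampling settings; the flag must be set before vllm is imported, and at this point in the PR series some sampling features are still marked TODO (see the "TODO on top-p top-k" commit), so treat this as a sketch rather than a guarantee.

import os

# Set the flag before importing vllm; entrypoints/llm.py reads it at import
# time to decide which LLMEngine implementation to load.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model
params = SamplingParams(temperature=0.8, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)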
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -68,6 +68,7 @@
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False


def get_default_cache_root():
@@ -450,6 +451,10 @@ def get_default_config_root():
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),

# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
}

# end-env-vars-definition
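For illustration (not part of the diff): the new entry parses the variable with bool(int(...)), so only integer strings are accepted. The snippet below mirrors the lambda registered above.

import os


def parse_vllm_use_v1() -> bool:
    # Mirrors the VLLM_USE_V1 lambda added to envs.py.
    return bool(int(os.getenv("VLLM_USE_V1", "0")))


os.environ.pop("VLLM_USE_V1", None)
print(parse_vllm_use_v1())   # False: unset falls back to "0"
os.environ["VLLM_USE_V1"] = "1"
print(parse_vllm_use_v1())   # True
# A non-integer value such as "true" would raise ValueError inside int().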
10 changes: 6 additions & 4 deletions vllm/model_executor/layers/logits_processor.py
@@ -48,14 +48,15 @@ def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
if sampling_metadata is not None:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)

# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
@@ -69,7 +70,8 @@
logits *= self.scale

# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)

return logits

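For illustration (not part of the diff): after this change LogitsProcessor.forward can be called without SamplingMetadata, in which case hidden-state pruning and the per-request logits processors are skipped and only the lm-head projection and scaling run. The sketch below reproduces just that control flow with stand-in tensors; keep_rows stands in for the row selection that SamplingMetadata drives in the real class.

from typing import Optional

import torch


def _prune(hidden_states: torch.Tensor, keep_rows: torch.Tensor) -> torch.Tensor:
    # Stand-in for _prune_hidden_states: keep only the rows to be sampled.
    return hidden_states.index_select(0, keep_rows)


def forward_sketch(hidden_states: torch.Tensor,
                   lm_head_weight: torch.Tensor,
                   keep_rows: Optional[torch.Tensor] = None,
                   scale: float = 1.0) -> torch.Tensor:
    # Pruning (and, in the real code, logits processors) run only when
    # sampling metadata is provided; the V1 engine passes None.
    if keep_rows is not None:
        hidden_states = _prune(hidden_states, keep_rows)
    logits = hidden_states @ lm_head_weight.t()
    return logits * scale


hidden = torch.randn(4, 8)
lm_head = torch.randn(16, 8)
print(forward_sketch(hidden, lm_head).shape)                        # torch.Size([4, 16])
print(forward_sketch(hidden, lm_head, torch.tensor([0, 3])).shape)  # torch.Size([2, 16])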
168 changes: 3 additions & 165 deletions vllm/transformers_utils/detokenizer.py
@@ -1,8 +1,10 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional

from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup)

from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup

@@ -161,167 +163,3 @@ def decode_sequence_inplace(self, seq: Sequence,
seq.output_text += new_decoded_token_text

return len(new_decoded_token_text)


def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""


def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts: List[str] = []
current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return " ".join(sub_texts)
else:
return "".join(sub_texts)


# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.

Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# We do not need to convert the whole prompt to tokens.
# Offset a little more in case we have special tokens.
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
skip_special_tokens=skip_special_tokens)
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.

If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.

This function will also return the new prefix offset and the new read
offset to be used in the next iteration.

The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.

Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None

# If the new token id is out of bounds, return an empty string.
if 0 <= new_token_id < len(tokenizer):
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
else:
new_tokens = [""]
output_tokens = prev_tokens + new_tokens

# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens

# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if tokenizer.is_fast or not tokenizer.get_added_vocab():
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
new_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:])
else:
prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)

if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
return new_tokens, "", prefix_offset, read_offset

new_text = new_text[len(prefix_text):]
return new_tokens, new_text, read_offset, len(output_tokens)
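For illustration (not part of the diff): the roughly 160 removed lines are not deleted outright. As the new import at the top of the file shows, convert_prompt_ids_to_tokens and detokenize_incrementally move into vllm/transformers_utils/detokenizer_utils.py, presumably so the new vllm.v1 detokenizer can share them. Below is a usage sketch of the moved helpers, assuming a HuggingFace tokenizer and the new import path; the exact decoded text depends on the tokenizer.

from transformers import AutoTokenizer

from vllm.transformers_utils.detokenizer_utils import (
    convert_prompt_ids_to_tokens, detokenize_incrementally)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
prompt_ids = tokenizer.encode("Hello, my name is")
generated_ids = tokenizer.encode(" Woosuk and I work on vLLM.")

# Seed the incremental state from the prompt only.
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
    tokenizer, prompt_ids)

all_ids = list(prompt_ids)
text = ""
for token_id in generated_ids:
    # Feed one "generated" token at a time, as the engine would.
    all_ids.append(token_id)
    new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer, all_ids, tokens, prefix_offset, read_offset)
    tokens.extend(new_tokens)
    text += new_text

print(text)  # should print the generated continuation, incrementally decoded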