diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4ff86573e664d..a707c6330e3ef 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -17,6 +17,7 @@ class _Backend(enum.Enum): FLASH_ATTN = enum.auto() + FLASH_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() @@ -110,6 +111,10 @@ def get_attn_backend( from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend + if backend == _Backend.FLASH_ATTN_VLLM_V1: + from vllm.v1.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend as FlashAttentionBackendV1) + return FlashAttentionBackendV1 if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -215,6 +220,9 @@ def which_attn_to_use( logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH + if envs.VLLM_USE_V1: + return _Backend.FLASH_ATTN_VLLM_V1 + # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: if not current_platform.has_device_capability(80): diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index ad0e970f36ff5..f67acdf660759 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -8,7 +8,7 @@ import cloudpickle import zmq -from vllm import AsyncEngineArgs, LLMEngine, SamplingParams +from vllm import AsyncEngineArgs, SamplingParams from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) # yapf conflicts with isort for this block @@ -21,12 +21,17 @@ RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) # yapf: enable -from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1 from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext +if VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine +else: + from vllm.engine.llm_engine import LLMEngine + CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, SchedulerConfig, LoRAConfig] @@ -136,14 +141,16 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, executor_class = LLMEngine._get_executor_cls(engine_config) - return cls( - ipc_path=ipc_path, - use_async_sockets=engine_config.model_config.use_async_output_proc, - **engine_config.to_dict(), - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context) + use_async_sockets = (engine_config.model_config.use_async_output_proc + and not VLLM_USE_V1) + + return cls(ipc_path=ipc_path, + use_async_sockets=use_async_sockets, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) def start(self): try: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1f7893d54de68..db97fe0a0285b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,10 +6,10 @@ from tqdm import tqdm +from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) from vllm.engine.arg_utils import EngineArgs, TaskOption -from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_hf_chat_template, apply_mistral_chat_template, @@ -31,6 +31,11 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of +if envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine # type: ignore +else: + from vllm.engine.llm_engine import LLMEngine # type: ignore + logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index 385db82d89249..a20271229c567 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -68,6 +68,7 @@ VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] + VLLM_USE_V1: bool = False def get_default_cache_root(): @@ -450,6 +451,10 @@ def get_default_config_root(): "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ "VLLM_DISABLED_KERNELS"].split(","), + + # If set, use the V1 code path. + "VLLM_USE_V1": + lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 1d5b6fad2e160..288f5a1134b6b 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -48,14 +48,15 @@ def forward( self, lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) + if sampling_metadata is not None: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) @@ -69,7 +70,8 @@ def forward( logits *= self.scale # Apply logits processors (if any). - logits = _apply_logits_processors(logits, sampling_metadata) + if sampling_metadata is not None: + logits = _apply_logits_processors(logits, sampling_metadata) return logits diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 2b418f3603a0b..345ea14f9f273 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,8 +1,10 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, Sequence, SequenceGroup) +from .detokenizer_utils import (convert_prompt_ids_to_tokens, + detokenize_incrementally) from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup @@ -161,167 +163,3 @@ def decode_sequence_inplace(self, seq: Sequence, seq.output_text += new_decoded_token_text return len(new_decoded_token_text) - - -def _replace_none_with_empty(tokens: List[Optional[str]]): - for i, token in enumerate(tokens): - if token is None: - tokens[i] = "" - - -def _convert_tokens_to_string_with_added_encoders( - tokenizer: AnyTokenizer, - output_tokens: List[str], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, -) -> str: - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 - # NOTE(woosuk): The following code is slow because it runs a for loop over - # the output_tokens. In Python, running a for loop over a list can be slow - # even when the loop body is very simple. - sub_texts: List[str] = [] - current_sub_text: List[str] = [] - all_special_tokens = set(tokenizer.all_special_tokens) - for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: - continue - if token in tokenizer.get_added_vocab(): - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] - sub_texts.append(token) - else: - current_sub_text.append(token) - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - if spaces_between_special_tokens: - return " ".join(sub_texts) - else: - return "".join(sub_texts) - - -# 5 is an arbitrary value that should work for all -# tokenizers (bigger = more conservative). -INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 - - -def convert_prompt_ids_to_tokens( - tokenizer: AnyTokenizer, - prompt_ids: List[int], - skip_special_tokens: bool = False, -) -> Tuple[List[str], int, int]: - """Converts the prompt ids to tokens and returns the tokens and offsets - for incremental detokenization. - - Note that not all tokens are converted to strings. Only the tokens that - are necessary for incremental detokenization are converted to strings. - """ - # We do not need to convert the whole prompt to tokens. - # Offset a little more in case we have special tokens. - new_tokens = tokenizer.convert_ids_to_tokens( - prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], - skip_special_tokens=skip_special_tokens) - read_offset = len(new_tokens) - prefix_offset = max( - read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) - # This is required to guard against out-of-vocab prompt token ids - _replace_none_with_empty(new_tokens) # type: ignore[arg-type] - return new_tokens, prefix_offset, read_offset - - -# Based on -# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 -# under Apache 2.0 license -def detokenize_incrementally( - tokenizer: AnyTokenizer, - all_input_ids: List[int], - prev_tokens: Optional[List[str]], - prefix_offset: int, - read_offset: int, - skip_special_tokens: bool = False, - spaces_between_special_tokens: bool = True, -) -> Tuple[List[str], str, int, int]: - """Detokenizes the input ids incrementally and returns the new tokens - and the new text. - - If `prev_tokens` is None, this function will convert the input ids to - tokens and return the tokens and the new text. Otherwise, it will return the - new tokens and the new text. - - This function will also return the new prefix offset and the new read - offset to be used in the next iteration. - - The offsets are necessary to defeat cleanup algorithms in the decode which - decide to add a space or not depending on the surrounding ids. - - Args: - tokenizer: The tokenizer to use. - all_input_ids: The input ids. The last id is the new token id. - prev_tokens: The previous tokens. If None, this function will convert - the input ids to tokens and return the tokens and the new text. - prefix_offset: The prefix offset. - read_offset: The read offset. - skip_special_tokens: Whether to skip special tokens. - spaces_between_special_tokens: Whether to add spaces between special - tokens. - """ - new_token_id = all_input_ids[-1] - # This is the first iteration for this sequence - is_first_iter = prev_tokens is None - if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) - assert prev_tokens is not None - - # If the new token id is out of bounds, return an empty string. - if 0 <= new_token_id < len(tokenizer): - # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) - if isinstance(new_tokens, str): - new_tokens = [new_tokens] - else: - new_tokens = [""] - output_tokens = prev_tokens + new_tokens - - # If this is the first iteration, return all tokens. - if is_first_iter: - new_tokens = output_tokens - - # The prefix text is necessary only to defeat cleanup algorithms in - # the decode which decide to add a space or not depending on the - # surrounding ids. - if tokenizer.is_fast or not tokenizer.get_added_vocab(): - prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) - else: - prefix_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:read_offset], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - new_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - if len(new_text) <= len(prefix_text) or new_text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. - # If it's in the middle, it's probably a real invalid id generated - # by the model - return new_tokens, "", prefix_offset, read_offset - - new_text = new_text[len(prefix_text):] - return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py new file mode 100644 index 0000000000000..37ff8a236e791 --- /dev/null +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -0,0 +1,167 @@ +from typing import List, Optional, Tuple + +from .tokenizer import AnyTokenizer + + +def _replace_none_with_empty(tokens: List[Optional[str]]): + for i, token in enumerate(tokens): + if token is None: + tokens[i] = "" + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: AnyTokenizer, + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. + sub_texts: List[str] = [] + current_sub_text: List[str] = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: AnyTokenizer, + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. + new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + # This is required to guard against out-of-vocab prompt token ids + _replace_none_with_empty(new_tokens) # type: ignore[arg-type] + return new_tokens, prefix_offset, read_offset + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: AnyTokenizer, + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. + prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. + """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None + + # If the new token id is out of bounds, return an empty string. + if 0 <= new_token_id < len(tokenizer): + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] + else: + new_tokens = [""] + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/v1/attention/__init__.py b/vllm/v1/attention/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/attention/backends/__init__.py b/vllm/v1/attention/backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py new file mode 100644 index 0000000000000..0530b1a6762ce --- /dev/null +++ b/vllm/v1/attention/backends/flash_attn.py @@ -0,0 +1,241 @@ +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.forward_context import get_forward_context +from vllm.vllm_flash_attn import flash_attn_varlen_func + + +class FlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash-attn-vllm-v1" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + +@dataclass +class FlashAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_start_loc: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + +class FlashAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + output = torch.ops.vllm.unified_flash_attention( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + ) + return output + + +@torch.library.custom_op("vllm::unified_flash_attention", + mutates_args=["kv_cache"]) +def unified_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + current_metadata = get_forward_context() + if current_metadata is None: + # Profiling run. + return torch.empty_like(query) + + assert current_metadata is not None + assert isinstance(current_metadata, FlashAttentionMetadata) + attn_metadata: FlashAttentionMetadata = current_metadata + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + # Reshape the input keys and values and store them in the cache. + key_cache = kv_cache[0] + value_cache = kv_cache[1] + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + window_size=window_size, + block_table=attn_metadata.block_table, + softcap=logits_soft_cap, + ) + return output.view(num_tokens, hidden_size) + + +@unified_flash_attention.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query) diff --git a/vllm/v1/core/__init__.py b/vllm/v1/core/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py new file mode 100644 index 0000000000000..9b735a8be10d7 --- /dev/null +++ b/vllm/v1/core/kv_cache_manager.py @@ -0,0 +1,108 @@ +from typing import Dict, List, Optional + +import numpy as np + +from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class KVCacheManager: + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + sliding_window: Optional[int] = None, + enable_caching: bool = True, + num_preallocate_tokens: int = 64, + ) -> None: + self.block_size = block_size + self.num_gpu_blocks = num_gpu_blocks + self.sliding_window = sliding_window + self.enable_caching = enable_caching + # NOTE(woosuk): To avoid frequent block allocation, we preallocate some + # blocks for each request. For example, when a request reaches the end + # of its block table, we preallocate N blocks in advance. This way, we + # reduce the overhead of updating free_block_ids and ref_cnts for each + # request every step (at the cost of some memory waste). + # NOTE(woosuk): This is different from the "lookahead" slots since this + # does not guarantee that the request always has N empty blocks. After + # the request gets N empty blocks, it starts to use the blocks without + # further allocation. When it uses up all the N empty blocks, it gets + # N new empty blocks. + self.num_preallocate_tokens = num_preallocate_tokens + self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) + + self.free_block_ids = list(range(num_gpu_blocks)) + self.req_to_block_ids: Dict[str, List[int]] = {} + self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32) + + def get_computed_blocks(self, request: Request) -> List[int]: + if not self.enable_caching: + # No prefix caching. + return [] + # TODO(woosuk): Implement hash-based caching. + return [] + + def append_slots( + self, + request: Request, + num_tokens: int, + ) -> Optional[List[int]]: + num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, + self.block_size) + req_block_ids = self.req_to_block_ids[request.request_id] + if num_required_blocks <= len(req_block_ids): + # No new block is needed. + return [] + + num_new_blocks = num_required_blocks - len(req_block_ids) + num_free_blocks = len(self.free_block_ids) + if num_new_blocks > num_free_blocks: + # Cannot allocate new blocks. + return None + + # Allocate new blocks. + num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks, + num_free_blocks) + new_block_ids = self._get_new_blocks(num_new_blocks) + req_block_ids.extend(new_block_ids) + self.ref_cnts[new_block_ids] += 1 + return new_block_ids + + def allocate_slots( + self, + request: Request, + num_tokens: int, + computed_block_ids: List[int], + ) -> Optional[List[int]]: + num_required_blocks = cdiv(num_tokens, self.block_size) + num_free_blocks = len(self.free_block_ids) + if num_required_blocks > num_free_blocks: + # Cannot allocate new blocks. + return None + + num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks, + num_free_blocks) + new_block_ids = self._get_new_blocks(num_new_blocks) + block_ids = computed_block_ids + new_block_ids + self.req_to_block_ids[request.request_id] = block_ids + self.ref_cnts[block_ids] += 1 + return new_block_ids + + def free(self, request: Request) -> None: + block_ids = self.req_to_block_ids.pop(request.request_id) + self.ref_cnts[block_ids] -= 1 + for block_id in block_ids: + ref_cnt = self.ref_cnts[block_id] + if ref_cnt == 0: + self.free_block_ids.append(block_id) + + def _get_new_blocks(self, num_blocks: int) -> List[int]: + assert num_blocks <= len(self.free_block_ids) + new_block_ids = self.free_block_ids[-num_blocks:] + self.free_block_ids = self.free_block_ids[:-num_blocks] + return new_block_ids diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py new file mode 100644 index 0000000000000..41659ff62747d --- /dev/null +++ b/vllm/v1/core/scheduler.py @@ -0,0 +1,412 @@ +from collections import deque +from dataclasses import dataclass +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalDataDict +from vllm.sampling_params import SamplingParams +from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + self.lora_config = lora_config + # TODO: Support LoRA. + assert lora_config is None, "V1 does not support LoRA yet." + + num_gpu_blocks = cache_config.num_gpu_blocks + assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 + # Create the block space manager. + self.kv_cache_manager = KVCacheManager( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=True) + self.block_size = self.cache_config.block_size + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + + # req_id -> Request + self.requests: Dict[str, Request] = {} + # Priority queues for requests. + self.waiting: Deque[Request] = deque() + self.running: List[Request] = [] + + # The request IDs that are finished in between the previous and the + # current steps. This is used to notify the workers about the finished + # requests so that they can free the cached states for those requests. + # This is flushed at the end of each scheduling step. + self.finished_req_ids: Set[str] = set() + + # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating + # them at each scheduling step. + # Request id -> RunningRequestData + self.running_reqs_data: Dict[str, RunningRequestData] = {} + + def schedule(self) -> "SchedulerOutput": + scheduled_new_reqs: List[Request] = [] + scheduled_resumed_reqs: List[Request] = [] + scheduled_running_reqs: List[Request] = [] + preempted_reqs: List[Request] = [] + + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and num_tokens, + # which is equal to len(prompt_token_ids) + len(output_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens. This is general enough to cover chunked prefills, + # prefix caching, and the "jump forward" optimization in the future. + + req_to_new_block_ids: Dict[str, List[int]] = {} + num_scheduled_tokens: Dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running): + if token_budget == 0: + break + + request = self.running[req_index] + num_new_tokens = request.num_tokens - request.num_computed_tokens + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + while True: + new_block_ids = self.kv_cache_manager.append_slots( + request, num_new_tokens) + if new_block_ids is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + break + else: + # The request can be scheduled. + scheduled_running_reqs.append(request) + + req_to_new_block_ids[request.request_id] = new_block_ids + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + break + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting: + if len(self.running) == self.max_num_running_reqs: + break + if token_budget == 0: + break + + request = self.waiting[0] + # Get already-cached tokens. + computed_block_ids = self.kv_cache_manager.get_computed_blocks( + request) + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_block_ids) * self.block_size + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed requests, + # which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + new_block_ids = self.kv_cache_manager.allocate_slots( + request, num_new_tokens, computed_block_ids) + if new_block_ids is None: + # The request cannot be scheduled. + break + request.num_computed_tokens = num_computed_tokens + + self.waiting.popleft() + self.running.append(request) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + req_to_new_block_ids[request.request_id] = ( + computed_block_ids + new_block_ids) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) == len(self.running)) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + ResumedRequestData.from_request( + req, req_to_new_block_ids[req.request_id], + req.num_computed_tokens) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_running_request_data( + req, req_to_new_block_ids[req.request_id], + req.num_computed_tokens) for req in scheduled_running_reqs + ] + preempted_req_ids = {req.request_id for req in preempted_reqs} + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_resumed_reqs=resumed_reqs_data, + scheduled_running_reqs=running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + preempted_req_ids=preempted_req_ids, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + ) + + self.finished_req_ids = set() + return scheduler_output + + def _make_running_request_data( + self, + request: Request, + new_block_ids: List[int], + num_computed_tokens: int, + ) -> "RunningRequestData": + # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating + # them at each scheduling step. + if request.request_id in self.running_reqs_data: + req_data = self.running_reqs_data[request.request_id] + req_data.new_block_ids = new_block_ids + req_data.num_computed_tokens = num_computed_tokens + else: + req_data = RunningRequestData.from_request(request, new_block_ids, + num_computed_tokens) + self.running_reqs_data[request.request_id] = req_data + return req_data + + def update_from_output( + self, + scheduler_output: "SchedulerOutput", + model_runner_output: "ModelRunnerOutput", + ) -> List[Tuple[Request, int]]: + # NOTE(woosuk): This method doesn't consider speculative decoding. + sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist() + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + new_running: List[Request] = [] + # (request, num_sampled_tokens) + sampled: List[Tuple[Request, int]] = [] + for request in self.running: + req_id = request.request_id + request.num_computed_tokens += num_scheduled_tokens[req_id] + # When the request's num_computed_tokens catches up its num_tokens, + # the request generates output tokens. Otherwise, we ignore the + # sampler output for the request. + assert request.num_computed_tokens <= request.num_tokens + if request.num_computed_tokens == request.num_tokens: + req_index = model_runner_output.req_id_to_index[req_id] + # NOTE(woosuk): Currently, we assume that each request + # generates at most one token at each step. + token_id = sampled_token_ids[req_index] + request.output_token_ids.append(token_id) + sampled.append((request, 1)) + # TODO: Update the KV cache manager for prefix caching. + + # Check if the request is finished. + stopped = self._check_stop(request) + if stopped: + continue + + new_running.append(request) + self.running = new_running + return sampled + + def _check_stop(self, request: Request) -> bool: + if (request.num_tokens >= self.max_model_len + or request.num_output_tokens >= request.max_tokens): + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + self._free_request(request) + return True + + sampling_params = request.sampling_params + last_token_id = request.output_token_ids[-1] + if (not sampling_params.ignore_eos + and last_token_id == request.eos_token_id): + request.status = RequestStatus.FINISHED_STOPPED + self._free_request(request) + return True + + if last_token_id in (sampling_params.stop_token_ids or ()): + request.status = RequestStatus.FINISHED_STOPPED + request.stop_reason = last_token_id + self._free_request(request) + return True + return False + + def add_request(self, request: Request) -> None: + self.waiting.append(request) + self.requests[request.request_id] = request + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + request_ids = set(request_ids) + + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + + if request.status == RequestStatus.RUNNING: + self.running.remove(request) + else: + self.waiting.remove(request) + request.status = finished_status + self._free_request(request) + + def _free_request(self, request: Request) -> None: + assert request.is_finished() + self.kv_cache_manager.free(request) + self.running_reqs_data.pop(request.request_id, None) + del self.requests[request.request_id] + self.finished_req_ids.add(request.request_id) + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + def has_unfinished_requests(self) -> bool: + return self.get_num_unfinished_requests() > 0 + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional[MultiModalDataDict] + sampling_params: SamplingParams + block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + block_ids: List[int], + num_computed_tokens: int, + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.inputs["prompt_token_ids"], + prompt=request.inputs.get("prompt"), + multi_modal_data=request.inputs.get("multi_modal_data"), + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class ResumedRequestData: + + req_id: str + block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + block_ids: List[int], + num_computed_tokens: int, + ) -> "ResumedRequestData": + return cls( + req_id=request.request_id, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class RunningRequestData: + + req_id: str + new_block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: Request, + new_block_ids: List[int], + num_computed_tokens: int, + ) -> "RunningRequestData": + return cls( + req_id=request.request_id, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class SchedulerOutput: + + scheduled_new_reqs: List[NewRequestData] + scheduled_resumed_reqs: List[ResumedRequestData] + scheduled_running_reqs: List[RunningRequestData] + + num_scheduled_tokens: Dict[str, int] + total_num_scheduled_tokens: int + + preempted_req_ids: Set[str] + finished_req_ids: Set[str] diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py new file mode 100644 index 0000000000000..511b417086c63 --- /dev/null +++ b/vllm/v1/engine/llm_engine.py @@ -0,0 +1,523 @@ +import time +from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, + Union) + +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderLLMInputs, InputRegistry, PromptType) +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.transformers_utils.config import try_get_generation_config +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import UsageContext +from vllm.v1.core.scheduler import Scheduler +from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.request import Request, RequestStatus +from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + + +class LLMEngine: + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[GPUExecutor], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + # Override the configs for V1. + # FIXME + if usage_context == UsageContext.LLM_CLASS: + scheduler_config.max_num_seqs = 1024 + scheduler_config.max_num_batched_tokens = 8192 + elif usage_context == UsageContext.OPENAI_API_SERVER: + scheduler_config.max_num_seqs = 1024 + scheduler_config.max_num_batched_tokens = 2048 + + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s, mm_processor_kwargs=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.num_scheduler_steps, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + model_config.mm_processor_kwargs, + ) + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig( + ) + self.log_stats = log_stats + + assert not self.model_config.skip_tokenizer_init + self.tokenizer = self._init_tokenizer() + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + self.detokenizer = Detokenizer(self.model_config.tokenizer) + + self.generation_config_fields = _load_generation_config_dict( + model_config) + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + model_config) + + # Request id -> Request + self.requests: Dict[str, Request] = {} + # NOTE(woosuk): Now that the detokenizer works asynchronously, we need + # to keep track of how many steps each request has been lagged behind + # in terms of detokenization. + # Request id -> how many detokenizer steps the request should wait for. + self.num_lagged_steps: Dict[str, int] = {} + # OPTIMIZATION: Cache the request output and update it incrementally. + # This is used to avoid creating a new RequestOutput object every step. + # Request id -> RequestOutput + self.request_outputs: Dict[str, RequestOutput] = {} + + self.model_executor = executor_class( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, + ) + assert self.model_config.task != "embedding" + self._initialize_kv_caches() + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + + def _initialize_kv_caches(self) -> None: + num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( + ) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = 0 + self.model_executor.initialize_cache(num_gpu_blocks) + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Create the LLM engine. + engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs], + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + ) -> None: + assert prompt_adapter_request is None + assert trace_headers is None + self._validate_model_inputs(processed_inputs) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + # TODO(woosuk): Support embedding mode. + assert isinstance(params, SamplingParams) + sampling_params = params.clone() + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + + # TODO(woosuk): Check max_logprobs + # TODO(woosuk): Support encoder-decoder models. + req = Request(request_id, processed_inputs, params, eos_token_id, + arrival_time) + self.requests[request_id] = req + self.num_lagged_steps[request_id] = 0 + self.scheduler.add_request(req) + + def stop_remote_worker_execution_loop(self) -> None: + raise NotImplementedError("TP not implemented yet.") + + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + assert priority == 0, "vLLM V1 does not support priority at the moment." + + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + ) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + self.scheduler.finish_requests(request_id, + RequestStatus.FINISHED_ABORTED) + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return len(self.requests) + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return len(self.requests) > 0 + + def step(self) -> List[RequestOutput]: + # NOTE(woosuk): This method may return an empty list when the + # detokenizer is still processing the outputs. This should not be + # considered as the end of the generation process. + # FIXME(woosuk): Currently, the step method is inefficient because it + # creates RequestOutput objects for all running requests, while they + # may not be needed unless the output is streamed to the client. + if self.scheduler.has_unfinished_requests(): + scheduler_output = self.scheduler.schedule() + output = self.model_executor.execute_model(scheduler_output) + sampled = self.scheduler.update_from_output( + scheduler_output, output) + self.send_to_detokenizer(sampled) + req_outputs = self.recv_from_detokenizer() + return req_outputs + + def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None: + inputs = DetokenizerInputs( + req_ids=[], + prompt_token_ids=[], + new_token_ids=[], + skip_special_tokens=[], + spaces_between_special_tokens=[], + free_req_ids=[], # TODO(woosuk): Implement freeing. + ) + for req, num_tokens in sampled: + inputs.req_ids.append(req.request_id) + if len(req.output_token_ids) == num_tokens: + # The request is first detokenized. + inputs.prompt_token_ids.append(req.prompt_token_ids) + else: + # The prompt token ids are already cached in the detokenizer. + inputs.prompt_token_ids.append([]) + inputs.new_token_ids.append(req.output_token_ids[-num_tokens:]) + inputs.skip_special_tokens.append( + req.sampling_params.skip_special_tokens) + inputs.spaces_between_special_tokens.append( + req.sampling_params.spaces_between_special_tokens) + + # Update the number of lagged steps. + self.num_lagged_steps[req.request_id] += 1 + self.detokenizer.send(inputs) + + def recv_from_detokenizer(self) -> List[RequestOutput]: + detokenizer_output = self.detokenizer.recv() + if detokenizer_output is None: + return [] + + req_outputs: List[RequestOutput] = [] + num_reqs = len(detokenizer_output.req_ids) + for i in range(num_reqs): + req_id = detokenizer_output.req_ids[i] + req = self.requests[req_id] + req.output_text += detokenizer_output.detokenized_texts[i] + + self.num_lagged_steps[req_id] -= 1 + finished = (self.num_lagged_steps[req_id] == 0 + and req.is_finished()) + req_output = self._make_request_output( + req, detokenizer_output.num_output_token_ids[i], + detokenizer_output.detokenized_texts[i], finished) + req_outputs.append(req_output) + + if finished: + del self.requests[req_id] + del self.num_lagged_steps[req_id] + del self.request_outputs[req_id] + return req_outputs + + def terminate_detokenizer(self) -> None: + self.detokenizer.terminate() + + def _make_request_output( + self, + request: Request, + num_output_tokens: int, + new_output_text: str, + finished: bool, + ) -> RequestOutput: + req_output = self.request_outputs.get(request.request_id) + if req_output is None: + # TODO: Support `n` > 1. + completion_output = CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, # TODO + finish_reason=None, + stop_reason=None, + lora_request=None, + ) + req_output = RequestOutput( + request_id=request.request_id, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + prompt_logprobs=None, # TODO + outputs=[completion_output], + finished=False, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + self.request_outputs[request.request_id] = req_output + + completion_output = req_output.outputs[0] + if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE: + completion_output.text += new_output_text + completion_output.token_ids = ( + request.output_token_ids[:num_output_tokens]) + elif request.sampling_params.output_kind == RequestOutputKind.DELTA: + completion_output.text = new_output_text + num_prev_tokens = len(completion_output.token_ids) + completion_output.token_ids = request.output_token_ids[ + num_prev_tokens:num_output_tokens] + elif (request.sampling_params.output_kind == + RequestOutputKind.FINAL_ONLY): + if finished: + completion_output.text = request.output_text + completion_output.token_ids = request.output_token_ids + else: + completion_output.text = "" + completion_output.token_ids = [] + + if finished: + completion_output.finish_reason = request.get_finished_reason() + completion_output.stop_reason = request.stop_reason + req_output.finished = finished + return req_output + + def check_health(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, + EncoderDecoderLLMInputs]): + prompt_ids = inputs.get("prompt_token_ids") + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + @classmethod + def validate_outputs(cls, outputs, output_type): + return outputs + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + + @classmethod + def _get_executor_cls(cls, engine_config: EngineConfig): + return GPUExecutor + + def is_tracing_enabled(self) -> bool: + return False + + def do_log_stats(self, *args, **kwargs) -> None: + pass + + def is_encoder_decoder_model(self) -> bool: + return False + + def start_profile(self) -> None: + pass + + def stop_profile(self) -> None: + pass + + def get_tokenizer_group(self, *args, **kwargs): + return self.tokenizer + + +def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: + config = try_get_generation_config( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.revision, + ) + + if config is None: + return {} + + return config.to_diff_dict() diff --git a/vllm/v1/executor/__init__.py b/vllm/v1/executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/gpu_executor.py new file mode 100644 index 0000000000000..c780c7031c3d6 --- /dev/null +++ b/vllm/v1/executor/gpu_executor.py @@ -0,0 +1,100 @@ +import os +from typing import Optional, Tuple + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_worker import Worker + +logger = init_logger(__name__) + + +class GPUExecutor: + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: Optional[ObservabilityConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + self.worker = self._create_worker() + self.worker.initialize() + self.worker.load_model() + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Worker: + """Return worker init args for a given rank.""" + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + speculative_config=self.speculative_config, + prompt_adapter_config=self.prompt_adapter_config, + observability_config=self.observability_config, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# GPU blocks: %d", num_gpu_blocks) + self.worker.initialize_cache(num_gpu_blocks) + self.worker.compile_or_warm_up_model() + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + output = self.worker.execute_model(scheduler_output) + return output + + def check_health(self) -> None: + # GPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py new file mode 100644 index 0000000000000..8574987728844 --- /dev/null +++ b/vllm/v1/outputs.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + + +@dataclass +class SamplerOutput: + + # [num_reqs] + sampled_token_ids: torch.Tensor + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: Optional[torch.Tensor] + # [num_reqs, max_num_logprobs + 1] + logprobs: Optional[torch.Tensor] + + # TODO: Support prompt logprobs. + prompt_logprob_token_ids: Optional[torch.Tensor] + prompt_logprobs: Optional[torch.Tensor] + + +@dataclass +class ModelRunnerOutput: + + # [num_reqs] + req_ids: List[str] + # req_id -> index + req_id_to_index: Dict[str, int] + + # [num_reqs] + sampled_token_ids_cpu: torch.Tensor + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids_cpu: Optional[torch.Tensor] + # [num_reqs, max_num_logprobs + 1] + logprobs_cpu: Optional[torch.Tensor] diff --git a/vllm/v1/request.py b/vllm/v1/request.py new file mode 100644 index 0000000000000..be7d4d165d280 --- /dev/null +++ b/vllm/v1/request.py @@ -0,0 +1,92 @@ +import enum +from typing import TYPE_CHECKING, List, Optional, Union + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import RequestMetrics + +if TYPE_CHECKING: + from vllm.inputs import DecoderOnlyInputs + + +class Request: + + def __init__( + self, + request_id: str, + inputs: "DecoderOnlyInputs", + sampling_params: SamplingParams, + eos_token_id: Optional[int], + arrival_time: float, + lora_request: Optional[LoRARequest] = None, + ) -> None: + self.request_id = request_id + self.inputs = inputs + self.sampling_params = sampling_params + # Because of LoRA, the eos token id can be different for each request. + self.eos_token_id = eos_token_id + self.metrics = RequestMetrics(arrival_time=arrival_time, + last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None) + self.lora_request = lora_request + + self.status = RequestStatus.WAITING + self.stop_reason: Union[int, str, None] = None + assert sampling_params.max_tokens is not None + self.max_tokens = sampling_params.max_tokens + + self.prompt = inputs.get("prompt") + self.prompt_token_ids = inputs["prompt_token_ids"] + self.num_prompt_tokens = len(self.prompt_token_ids) + self.output_token_ids: List[int] = [] + self.output_text = "" + self.num_computed_tokens = 0 + + @property + def num_tokens(self) -> int: + return self.num_prompt_tokens + len(self.output_token_ids) + + @property + def num_output_tokens(self) -> int: + return len(self.output_token_ids) + + def is_finished(self) -> bool: + return RequestStatus.is_finished(self.status) + + def get_finished_reason(self) -> Union[str, None]: + return RequestStatus.get_finished_reason(self.status) + + +class RequestStatus(enum.IntEnum): + """Status of a sequence.""" + WAITING = 0 + RUNNING = 1 + PREEMPTED = 2 + # Note: anything after PREEMPTED (2) will be considered + # as a finished status. + FINISHED_STOPPED = 3 + FINISHED_LENGTH_CAPPED = 4 + FINISHED_ABORTED = 5 + FINISHED_IGNORED = 6 + + @staticmethod + def is_finished(status: "RequestStatus") -> bool: + return status > RequestStatus.PREEMPTED + + @staticmethod + def get_finished_reason(status: "RequestStatus") -> Union[str, None]: + return _FINISHED_REASON_MAP.get(status) + + +# Mapping of finished statuses to their finish reasons. +# NOTE: The ignored sequences are the sequences whose prompt lengths +# are longer than the model's length cap. Therefore, the stop +# reason should also be "length" as in OpenAI API. +_FINISHED_REASON_MAP = { + RequestStatus.FINISHED_STOPPED: "stop", + RequestStatus.FINISHED_LENGTH_CAPPED: "length", + RequestStatus.FINISHED_ABORTED: "abort", + RequestStatus.FINISHED_IGNORED: "length", +} diff --git a/vllm/v1/sample/__init__.py b/vllm/v1/sample/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py new file mode 100644 index 0000000000000..28614377b27b9 --- /dev/null +++ b/vllm/v1/sample/metadata.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch + + +@dataclass +class SamplingMetadata: + + temperature: torch.Tensor + all_greedy: bool + all_random: bool + + top_p: torch.Tensor + top_k: torch.Tensor + no_top_p: bool + no_top_k: bool + + generators: List[Optional[torch.Generator]] + no_generator: bool + + max_num_logprobs: int diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py new file mode 100644 index 0000000000000..157c4dd6d771e --- /dev/null +++ b/vllm/v1/sample/sampler.py @@ -0,0 +1,161 @@ +"""A layer that samples the next tokens from the model's outputs.""" +from typing import List, Optional + +import torch +import torch.nn as nn + +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata + +_SAMPLING_EPS = 1e-5 + + +class Sampler(nn.Module): + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + logits = self.apply_temperature(logits, sampling_metadata.temperature) + logits = self.apply_top_k_top_p(logits, sampling_metadata) + + probs = self.get_probs(logits) + sampled = self.sample(probs, sampling_metadata) + # Use int32 to reduce the tensor size. + sampled = sampled.to(torch.int32) + + if sampling_metadata.max_num_logprobs > 0: + logprobs = self.get_logprobs(logits) + # FIXME: Mask the sampled token_id, get topk logprobs, + # and concatenate the topk with the sampled token_id. + topk_logprobs, topk_indices = torch.topk( + logprobs, sampling_metadata.max_num_logprobs, dim=-1) + # Use int32 to reduce the tensor size. + topk_indices = topk_indices.to(torch.int32) + else: + topk_logprobs = None + topk_indices = None + + sampler_output = SamplerOutput( + sampled_token_ids=sampled, + logprob_token_ids=topk_indices, + logprobs=topk_logprobs, + prompt_logprob_token_ids=None, + prompt_logprobs=None, + ) + return sampler_output + + def apply_temperature( + self, + logits: torch.Tensor, + temp: torch.Tensor, + ) -> torch.Tensor: + # Use float32 to apply temperature scaling. + logits = logits.to(torch.float32) + # Avoid division by zero. + temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) + # Use in-place division to avoid creating a new tensor. + logits.div_(temp.unsqueeze(dim=1)) + return logits + + def apply_top_k_top_p( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + return _apply_top_k_top_p( + logits, + sampling_metadata.no_top_k, + sampling_metadata.top_k, + sampling_metadata.no_top_p, + sampling_metadata.top_p, + ) + + def get_probs(self, logits: torch.Tensor) -> torch.Tensor: + return torch.softmax(logits, dim=-1, dtype=torch.float32) + + def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return torch.log_softmax(logits, dim=-1, dtype=torch.float32) + + def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: + return probs.argmax(dim=-1).view(-1) + + def random_sample( + self, + probs: torch.Tensor, + generators: List[Optional[torch.Generator]], + no_generator: bool, + ) -> torch.Tensor: + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + q.exponential_() + if not no_generator: + assert len(generators) == probs.shape[0] + # TODO(woosuk): This can be slow because we handle each request + # one by one. Optimize this. + for i, generator in enumerate(generators): + if generator is not None: + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + def sample( + self, + probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + assert not (sampling_metadata.all_greedy + and sampling_metadata.all_random) + if sampling_metadata.all_greedy: + return self.greedy_sample(probs) + if sampling_metadata.all_random: + return self.random_sample(probs, sampling_metadata.generators, + sampling_metadata.no_generator) + + greedy_sampled = self.greedy_sample(probs) + random_sampled = self.random_sample(probs, + sampling_metadata.generators, + sampling_metadata.no_generator) + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, + ) + return sampled + + +# TODO(woosuk): Optimize this with a custom kernel. +def _apply_top_k_top_p( + logits: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, +) -> torch.Tensor: + if no_top_k and no_top_p: + return logits + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if not no_top_k: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if not no_top_p: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits diff --git a/vllm/v1/tokenizer/__init__.py b/vllm/v1/tokenizer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py new file mode 100644 index 0000000000000..4bbcf4717981e --- /dev/null +++ b/vllm/v1/tokenizer/detokenizer.py @@ -0,0 +1,215 @@ +import multiprocessing +from dataclasses import dataclass +from typing import Dict, List, Optional + +import msgspec +import zmq +from msgspec import msgpack + +from vllm.transformers_utils.detokenizer_utils import ( + convert_prompt_ids_to_tokens, detokenize_incrementally) +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils import get_open_port + + +class DetokenizerInputs(msgspec.Struct): + + # [num_reqs] + req_ids: List[str] + # A request's prompt token ids is sent to the detokenizer only when + # the request is first detokenized. Otherwise, an empty list is sent. + prompt_token_ids: List[List[int]] + new_token_ids: List[List[int]] + skip_special_tokens: List[bool] + spaces_between_special_tokens: List[bool] + + # [num_free_reqs] + free_req_ids: List[str] + + +class DetokenizerOutputs(msgspec.Struct): + + # [num_reqs] + req_ids: List[str] + detokenized_texts: List[str] + # NOTE(woosuk): The number of the output token ids of each request + # at the time of detokenization. The detokenizer returns this to the engine + # because the request state (including the output token ids) is + # asynchronously updated in the engine, while RequestOutput requires the + # output token ids to be consistent with the detokenized text. + num_output_token_ids: List[int] + + +class Detokenizer: + + def __init__(self, tokenizer_name: str): + # FIXME(woosuk): Currently, the detokenizer is just a hacky prototype. + # For example, it does not terminate properly. We need to improve this. + self.push_port = get_open_port() + self.pull_port = get_open_port() + self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port, + self.pull_port) + self.detokenizer.start() + + self.zmq_context = zmq.Context() + self.push_socket = self.zmq_context.socket(zmq.PUSH) + self.push_socket.connect(f"tcp://localhost:{self.push_port}") + self.pull_socket = self.zmq_context.socket(zmq.PULL) + self.pull_socket.connect(f"tcp://localhost:{self.pull_port}") + self.poller = zmq.Poller() + self.poller.register(self.pull_socket, zmq.POLLIN) + self.msgpack_encoder = msgpack.Encoder() + self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs) + + def send(self, inputs: DetokenizerInputs) -> None: + self.push_socket.send(self.msgpack_encoder.encode(inputs), + flags=zmq.NOBLOCK) + + def recv(self) -> Optional[DetokenizerOutputs]: + socks = dict(self.poller.poll(timeout=0)) + if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN: + msg = self.pull_socket.recv() + return self.msgpack_decoder.decode(msg) + return None + + def terminate(self) -> None: + self.push_socket.send(b"", flags=zmq.NOBLOCK) + self.detokenizer.join() + + +class DetokenizerProc(multiprocessing.Process): + + def __init__( + self, + tokenizer_name: str, + pull_port: int, + push_port: int, + ): + super().__init__() + self.tokenizer_name = tokenizer_name + # NOTE: The pull_port of the detokenizer should be the same as the + # push_port of the engine. Vice versa. + self.pull_port = pull_port + self.push_port = push_port + + def run(self): + # Initialize these objects after the process is forked since they are + # not picklable. + self.msgpack_encoder = msgpack.Encoder() + self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs) + self.tokenizer = get_tokenizer(self.tokenizer_name) + # req_id -> RequestState + self.request_states: Dict[str, RequestState] = {} + + self.zmq_context = zmq.Context() + self.pull_socket = self.zmq_context.socket(zmq.PULL) + self.pull_socket.bind(f"tcp://*:{self.pull_port}") + self.push_socket = self.zmq_context.socket(zmq.PUSH) + self.push_socket.bind(f"tcp://*:{self.push_port}") + + while True: + message = self.pull_socket.recv() + if message == b"": + # Terminate signal. + break + inputs = self.msgpack_decoder.decode(message) + + for req_id in inputs.free_req_ids: + self.free(req_id) + + detokenized_texts: List[str] = [] + num_output_token_ids: List[int] = [] + num_reqs = len(inputs.req_ids) + for i in range(num_reqs): + req_id = inputs.req_ids[i] + if req_id not in self.request_states: + self.add_request( + request_id=req_id, + prompt_token_ids=inputs.prompt_token_ids[i], + skip_special_tokens=inputs.skip_special_tokens[i], + spaces_between_special_tokens=inputs. + spaces_between_special_tokens[i], + ) + new_str = self.detokenize(req_id, inputs.new_token_ids[i]) + detokenized_texts.append(new_str) + req_state = self.request_states[req_id] + num_output_token_ids.append( + len(req_state.token_ids) - req_state.num_prompt_tokens) + + detokenized = DetokenizerOutputs( + req_ids=inputs.req_ids, + detokenized_texts=detokenized_texts, + num_output_token_ids=num_output_token_ids, + ) + self.push_socket.send(self.msgpack_encoder.encode(detokenized), + flags=zmq.NOBLOCK) + + def add_request( + self, + request_id: str, + prompt_token_ids: List[int], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, + ) -> None: + tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( + tokenizer=self.tokenizer, + prompt_ids=prompt_token_ids, + skip_special_tokens=skip_special_tokens, + ) + self.request_states[request_id] = RequestState( + req_id=request_id, + token_ids=prompt_token_ids, + tokens=tokens, + num_prompt_tokens=len(prompt_token_ids), + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + def free(self, request_id: str) -> None: + del self.request_states[request_id] + + def detokenize(self, request_id: str, new_token_ids: List[int]) -> str: + # TODO(woosuk): This method becomes very inefficient when the number of + # new_token_ids is more than 1. We need to optimize this. + req_state = self.request_states[request_id] + decoded_text = "" + for new_token_id in new_token_ids: + req_state.token_ids.append(new_token_id) + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=req_state.token_ids, + prev_tokens=req_state.tokens, + prefix_offset=req_state.prefix_offset, + read_offset=req_state.read_offset, + skip_special_tokens=req_state.skip_special_tokens, + spaces_between_special_tokens=req_state. + spaces_between_special_tokens, + ) + + req_state.tokens.extend(new_tokens) + req_state.prefix_offset = prefix_offset + req_state.read_offset = read_offset + req_state.output_text += new_decoded_token_text + decoded_text += new_decoded_token_text + return decoded_text + + +@dataclass +class RequestState: + + req_id: str + + token_ids: List[int] + tokens: List[str] + num_prompt_tokens: int + + prefix_offset: int + read_offset: int + + skip_special_tokens: bool + spaces_between_special_tokens: bool + + output_text: str = "" diff --git a/vllm/v1/worker/__init__.py b/vllm/v1/worker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py new file mode 100644 index 0000000000000..e84645ac7a4ae --- /dev/null +++ b/vllm/v1/worker/gpu_model_runner.py @@ -0,0 +1,690 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MultiModalDataDict +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, + is_pin_memory_available) +from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionMetadata) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class GPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + + # Model-related. + self.num_attn_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[torch.Tensor] = [] + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=req_data.prompt_token_ids, + prompt=req_data.prompt, + multi_modal_data=req_data.multi_modal_data, + sampling_params=req_data.sampling_params, + generator=None, # TODO + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table[:num_reqs].copy_( + self.input_batch.block_table_cpu_tensor[:num_reqs], + non_blocking=True) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + assert max_num_scheduled_tokens > 0 + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + indices = np.arange(num_reqs) + req_indices = np.repeat(indices, num_scheduled_tokens) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), + (num_reqs, 1)) + mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] + arange = arange_matrix[mask] + + # Get positions. + positions = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + positions_np = positions.numpy() + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions_np + req_indices * self.max_model_len + token_indices = torch.from_numpy(token_indices) + input_ids = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.index_select(torch.from_numpy( + self.input_batch.token_ids_cpu).flatten(), + 0, + token_indices, + out=input_ids) + + # Calculate the slot mapping. + block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ + token_indices // self.block_size] + block_offsets = token_indices % self.block_size + slot_mapping = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.add(block_numbers * self.block_size, + block_offsets, + out=slot_mapping) + + # Prepare the attention metadata. + query_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + query_start_loc_np = query_start_loc.numpy() + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) + + seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + max_seq_len = seq_lens.max() + seq_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + seq_start_loc_np = seq_start_loc.numpy() + seq_start_loc_np[0] = 0 + np.cumsum(seq_lens, out=seq_start_loc_np[1:]) + + input_ids = input_ids.to(self.device, non_blocking=True) + positions = positions.to(self.device, non_blocking=True).long() + query_start_loc = query_start_loc.to(self.device, non_blocking=True) + seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) + slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() + attn_metadata = FlashAttentionMetadata( + max_query_len=max_num_scheduled_tokens, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_start_loc=seq_start_loc, + block_table=self.input_batch.block_table[:num_reqs], + slot_mapping=slot_mapping, + ) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + return input_ids, positions, attn_metadata, logits_indices + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + return sampling_metadata + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + inputs = self._prepare_inputs(scheduler_output) + input_ids, positions, attn_metadata, logits_indices = inputs + + with set_forward_context(attn_metadata): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=self.kv_caches, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self._prepare_sampling(scheduler_output) + sampler_output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + # NOTE: CPU-GPU synchronization happens here. + sampled_token_ids = sampler_output.sampled_token_ids.cpu() + sampled_token_ids_list = sampled_token_ids.tolist() + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + num_reqs = self.input_batch.num_reqs + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len <= req_state.num_tokens + if seq_len == req_state.num_tokens: + # Append the sampled token to the output token ids. + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators[i] + if generator is not None: + offset = generator.get_offset() + generator = generator.set_offset(offset - 1) + self.input_batch.generators[i] = generator + + if sampler_output.logprob_token_ids is None: + logprob_token_ids = None + else: + logprob_token_ids = sampler_output.logprob_token_ids.cpu() + if sampler_output.logprobs is None: + logprobs = None + else: + logprobs = sampler_output.logprobs.cpu() + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=logprob_token_ids, + logprobs_cpu=logprobs, + ) + return model_runner_output + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: # noqa: SIM117 + with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: + input_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + positions = torch.zeros(num_tokens, + dtype=torch.long, + device=self.device) + kv_caches = [None for _ in range(self.num_attn_layers)] + model(input_ids, positions, kv_caches, attn_metadata=None) + return + + @torch.inference_mode() + def profile_run(self) -> None: + self._dummy_run(self.model, self.max_num_tokens) + torch.cuda.synchronize() + return + + @torch.inference_mode() + def capture_model(self) -> None: + # TODO: Implement CUDA graph support. + return + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device)) + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + self.generators: List[Optional[torch.Generator]] = [None + ] * max_num_reqs + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + self.req_ids[req_index] = request.req_id + self.req_id_to_index[request.req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_index) + elif sampling_params.sampling_type == SamplingType.RANDOM: + self.random_reqs.add(req_index) + elif sampling_params.sampling_type == SamplingType.RANDOM_SEED: + # TODO(woosuk): Support per-request random seed. + raise NotImplementedError("Per-request seed is not supported yet.") + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_index) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_index) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[request.req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_index) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators[req_index] = None + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.generators[empty_index] = self.generators[last_req_index] + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators[:self.num_reqs], + no_generator=self.no_generator, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def no_generator(self) -> bool: + return len(self.generators) == 0 + + @property + def max_num_logprobs(self) -> int: + if self.num_logprobs: + return max(self.num_logprobs.values()) + else: + return 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py new file mode 100644 index 0000000000000..8c5ca2ec35666 --- /dev/null +++ b/vllm/v1/worker/gpu_worker.py @@ -0,0 +1,245 @@ +"""A GPU worker class.""" +import gc +import os +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + + +class Worker: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + speculative_config: Optional[SpeculativeConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = GPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + ) + + def initialize(self): + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self) -> None: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + # if self.model_runner.lora_manager: + # self.model_runner.remove_all_loras() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, 0 + + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks.""" + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_gpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_gpu_blocks) + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + return output + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") + + +def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total