Skip to content

Commit

Permalink
[Frontend] Refactor prompt processing (vllm-project#4028)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
DarkLight1337 and ywang96 authored Jul 22, 2024
1 parent 89c1c6a commit 739b61a
Show file tree
Hide file tree
Showing 24 changed files with 698 additions and 390 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs
from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

Expand Down Expand Up @@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{
dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

Expand Down
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptStrictInputs
.. autodata:: vllm.inputs.PromptInputs

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
internally for each model.


To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
Expand Down
4 changes: 2 additions & 2 deletions tests/engine/output_processor/test_stop_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
("This text ends with EOS token", "</s>", 2),
])
@pytest.mark.parametrize("ignore_eos", [True, False, None])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
@pytest.mark.parametrize("ignore_eos", [True, False])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.skip_global_cleanup
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
ignore_eos: bool, include_stop_str_in_output: bool):
Expand Down
5 changes: 4 additions & 1 deletion tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ async def _async_serving_chat_init():
model_config,
served_model_names=[MODEL_NAME],
response_role="assistant",
chat_template=CHAT_TEMPLATE)
chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion


Expand Down
4 changes: 2 additions & 2 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
Expand All @@ -19,7 +19,7 @@
"__version__",
"LLM",
"ModelRegistry",
"PromptStrictInputs",
"PromptInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
Expand Down
7 changes: 0 additions & 7 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser,
Expand All @@ -841,12 +840,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser


Expand Down
63 changes: 22 additions & 41 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)

from transformers import PreTrainedTokenizer

Expand Down Expand Up @@ -151,7 +151,10 @@ def process_exception(self,
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)

def add_request(self, request_id: str,
def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
Expand All @@ -166,6 +169,9 @@ def add_request(self, request_id: str,

self.new_requests_event.set()

if verbose:
logger.info("Added request %s.", request_id)

return stream

def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
Expand Down Expand Up @@ -299,14 +305,14 @@ async def process_model_inputs_async(
return self.input_processor(llm_inputs)

async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
Expand Down Expand Up @@ -353,8 +359,6 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for :class:`LLMEngine`.
Expand All @@ -368,13 +372,11 @@ def __init__(self,
engine_use_ray: bool,
*args,
log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)

self.background_loop: Optional[asyncio.Future] = None
Expand Down Expand Up @@ -468,7 +470,6 @@ def from_engine_args(
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
Expand Down Expand Up @@ -667,30 +668,9 @@ async def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")

max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:max_log_len]

logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)

if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
Expand All @@ -706,6 +686,7 @@ async def add_request(

stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
arrival_time=arrival_time,
Expand All @@ -721,7 +702,7 @@ async def generate(
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
Expand Down Expand Up @@ -804,7 +785,7 @@ async def encode(
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
Expand Down Expand Up @@ -882,7 +863,7 @@ async def _process_request(
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
Expand Down
9 changes: 5 additions & 4 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union

Expand Down Expand Up @@ -522,7 +523,7 @@ def _add_processed_request(
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
Expand Down Expand Up @@ -603,7 +604,7 @@ def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Add a request to the engine's request pool.
Expand Down Expand Up @@ -677,7 +678,7 @@ def _create_sequence_group_with_sampling(
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
Expand Down
Loading

0 comments on commit 739b61a

Please sign in to comment.