[BugFix] Fix server crash on empty prompt #7746

Merged
10 commits merged on Aug 23, 2024
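In short: the engine now validates processed inputs and raises ValueError("Empty prompt") instead of crashing when given an empty prompt. A minimal sketch of the post-fix behavior (illustrative only, not part of the diff; the model name is just an example):

    from vllm import LLM

    # After this fix, an empty prompt is rejected up front with a
    # ValueError rather than crashing the engine.
    llm = LLM(model="gpt2")
    try:
        llm.generate([""])
    except ValueError as e:
        print(e)  # Empty prompt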
9 changes: 9 additions & 0 deletions tests/entrypoints/llm/test_prompt_validation.py
@@ -0,0 +1,9 @@
import pytest

from vllm import LLM


def test_empty_prompt():
    llm = LLM(model="gpt2")
    with pytest.raises(ValueError, match='Empty prompt'):
        llm.generate([""])
22 changes: 22 additions & 0 deletions tests/entrypoints/openai/test_prompt_validation.py
@@ -0,0 +1,22 @@
# imports for guided decoding tests
import re

import openai
import pytest

from ...utils import RemoteOpenAIServer


@pytest.mark.asyncio
async def test_empty_prompt():
    model_name = "gpt2"
    server_args = ["--enforce-eager"]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()

        with pytest.raises(openai.BadRequestError,
                           match=re.compile('.+Empty prompt.+')):
            await client.completions.create(model=model_name,
                                            prompt="",
                                            max_tokens=5,
                                            temperature=0.0)
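On the server path, the ValueError raised inside the engine surfaces to clients as an HTTP 400, which the openai client maps to openai.BadRequestError — exactly what the test above asserts. A sketch of checking this by hand (assumes a vLLM OpenAI-compatible server already running locally; the base URL and placeholder API key are assumptions):

    import openai

    # Assumed local server; adjust base_url if yours differs.
    client = openai.OpenAI(base_url="http://localhost:8000/v1",
                           api_key="EMPTY")
    try:
        client.completions.create(model="gpt2", prompt="", max_tokens=5)
    except openai.BadRequestError as e:
        print(e)  # 400 error whose message contains "Empty prompt"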
12 changes: 12 additions & 0 deletions vllm/engine/llm_engine.py
@@ -591,6 +591,7 @@ def _add_processed_request(
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
@@ -1647,3 +1648,14 @@ def is_encoder_decoder_model(self):

    def is_embedding_model(self):
        return self.model_config.is_embedding_model

    def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                   EncoderDecoderLLMInputs]):
        if self.is_encoder_decoder_model():
            if "encoder_prompt_token_ids" not in inputs or \
                    len(inputs["encoder_prompt_token_ids"]) == 0:
                raise ValueError("Empty prompt")
        else:
            if "prompt_token_ids" not in inputs or len(
                    inputs["prompt_token_ids"]) == 0:
                raise ValueError("Empty prompt")
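Design note: calling _validate_model_inputs from _add_processed_request puts the check on the common path shared by the offline LLM API and the OpenAI-compatible server, so both entrypoints reject an empty prompt before any sequences are created, covering decoder-only ("prompt_token_ids") and encoder-decoder ("encoder_prompt_token_ids") inputs alike.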