Skip to content

Commit

Permalink
Commit from ustream PR vllm-project#7746
Browse files Browse the repository at this point in the history
Validate the that the input prompts aren't empty

This avoids an async loop crash that takes down the server

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Jefferson Fialho <jfialho@ibm.com>
  • Loading branch information
maxdebayser authored and fialhocoelho committed Aug 22, 2024
1 parent 208344f commit 1b74b46
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
2 changes: 2 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ async def process_model_inputs_async(
inputs,
request_id=request_id,
)
self._validate_enc_dec_inputs(model_inputs)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
Expand All @@ -555,6 +556,7 @@ async def process_model_inputs_async(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
self._validate_dec_only_inputs(model_inputs)

return self.input_processor(model_inputs)

Expand Down
12 changes: 12 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,7 @@ def process_model_inputs(
inputs,
request_id=request_id,
)
self._validate_enc_dec_inputs(model_inputs)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
Expand All @@ -958,6 +959,7 @@ def process_model_inputs(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
self._validate_dec_only_inputs(model_inputs)

return self.input_processor(model_inputs)

Expand Down Expand Up @@ -1631,3 +1633,13 @@ def is_encoder_decoder_model(self):

def is_embedding_model(self):
return self.model_config.is_embedding_model

def _validate_dec_only_inputs(self, inputs: LLMInputs):
if "prompt_token_ids" not in inputs or len(
inputs["prompt_token_ids"]) == 0:
raise ValueError("Empty prompt")

def _validate_enc_dec_inputs(self, inputs: EncoderDecoderLLMInputs):
if "encoder_prompt_token_ids" not in inputs or\
len(inputs["encoder_prompt_token_ids"]) == 0:
raise ValueError("Empty prompt")

0 comments on commit 1b74b46

Please sign in to comment.