From 1b74b46621a20bc56dc0d8e2b120cfc418d403c5 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 21 Aug 2024 14:50:44 -0300 Subject: [PATCH] Commit from ustream PR #7746 Validate the that the input prompts aren't empty This avoids an async loop crash that takes down the server Signed-off-by: Max de Bayser Signed-off-by: Jefferson Fialho --- vllm/engine/async_llm_engine.py | 2 ++ vllm/engine/llm_engine.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0e17696724198..88716df69e806 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -543,6 +543,7 @@ async def process_model_inputs_async( inputs, request_id=request_id, ) + self._validate_enc_dec_inputs(model_inputs) else: if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " @@ -555,6 +556,7 @@ async def process_model_inputs_async( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + self._validate_dec_only_inputs(model_inputs) return self.input_processor(model_inputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 36cb6ce795f3e..54676f19ca866 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -946,6 +946,7 @@ def process_model_inputs( inputs, request_id=request_id, ) + self._validate_enc_dec_inputs(model_inputs) else: if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " @@ -958,6 +959,7 @@ def process_model_inputs( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + self._validate_dec_only_inputs(model_inputs) return self.input_processor(model_inputs) @@ -1631,3 +1633,13 @@ def is_encoder_decoder_model(self): def is_embedding_model(self): return self.model_config.is_embedding_model + + def _validate_dec_only_inputs(self, inputs: LLMInputs): + if "prompt_token_ids" not in inputs or len( + inputs["prompt_token_ids"]) == 0: + raise ValueError("Empty prompt") + + def _validate_enc_dec_inputs(self, inputs: EncoderDecoderLLMInputs): + if "encoder_prompt_token_ids" not in inputs or\ + len(inputs["encoder_prompt_token_ids"]) == 0: + raise ValueError("Empty prompt")