From 4b33277ef67277ad29a3249bb1e501cd893bf3ce Mon Sep 17 00:00:00 2001
From: Jeff Fialho
Date: Sat, 3 Aug 2024 20:01:38 -0300
Subject: [PATCH] [Frontend] Warn if user `max_model_len` is greater than
 derived `max_model_len` (#7080)

Signed-off-by: Jefferson Fialho
Co-authored-by: Nick Hill
Signed-off-by: Alvant
---
 vllm/config.py | 19 +++++++++++++------
 vllm/envs.py   | 10 ++++++++++
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index ef56e2b6395be..028f4eed8f4a2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -6,6 +6,7 @@
 import torch
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
@@ -1541,15 +1542,21 @@ def _get_and_verify_max_len(
                     "Disabling sliding window is not supported for models "
                     "model_max_length in the config. Please raise an issue "
                     "so we can investigate.")
-            pass
         else:
-            raise ValueError(
+            msg = (
                 f"User-specified max_model_len ({max_model_len}) is greater "
-                "than the derived max_model_len "
-                f"({max_len_key}={derived_max_model_len} or model_max_length="
+                f"than the derived max_model_len ({max_len_key}="
+                f"{derived_max_model_len} or model_max_length="
                 f"{model_max_length} in model's config.json). This may lead "
-                "to incorrect model outputs or CUDA errors. Make sure the "
-                "value is correct and within the model context size.")
+                "to incorrect model outputs or CUDA errors.")
+            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
+                logger.warning(
+                    "%s Make sure the value is correct and within the "
+                    "model context size.", msg)
+            else:
+                raise ValueError(
+                    f"{msg} To allow overriding this maximum, set "
+                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
 
     return int(max_model_len)
 
diff --git a/vllm/envs.py b/vllm/envs.py
index a78bad6a2b273..089a39d8e029d 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -50,6 +50,7 @@
     VLLM_NO_DEPRECATION_WARNING: bool = False
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
+    VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
 
 
 def get_default_cache_root():
@@ -331,6 +332,15 @@ def get_default_config_root():
     # If set, vllm will skip the deprecation warnings.
     "VLLM_NO_DEPRECATION_WARNING":
     lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
+
+    # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
+    # the user to specify a max sequence length greater than
+    # the max length derived from the model's config.json.
+    # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
+    "VLLM_ALLOW_LONG_MAX_MODEL_LEN":
+    lambda:
+    (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
+     ("1", "true")),
 }
 
 # end-env-vars-definition
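
Below is a minimal usage sketch (not part of the patch itself) showing how the new escape hatch would be exercised from Python. The model name and the 4096 value are illustrative only, and the snippet assumes the public vllm.LLM entrypoint; without the env var, the same arguments would hit the ValueError path added above.

    import os

    # Enable the override before vLLM reads its env vars; "1" or "true" turns it on.
    os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

    from vllm import LLM

    # facebook/opt-125m derives a 2048-token max_model_len from its config.json,
    # so requesting 4096 now only logs a warning instead of raising ValueError.
    llm = LLM(model="facebook/opt-125m", max_model_len=4096)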