Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Warn if user max_model_len is greater than derived max_model_len #7080

Merged
merged 7 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
Expand Down Expand Up @@ -1541,15 +1542,21 @@ def _get_and_verify_max_len(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate.")
pass
else:
raise ValueError(
msg = (
f"User-specified max_model_len ({max_model_len}) is greater "
"than the derived max_model_len "
f"({max_len_key}={derived_max_model_len} or model_max_length="
f"than the derived max_model_len ({max_len_key}="
f"{derived_max_model_len} or model_max_length="
f"{model_max_length} in model's config.json). This may lead "
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size.")
"to incorrect model outputs or CUDA errors.")
if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
logger.warning(
"%s Make sure the value is correct and within the "
"model context size.", msg)
else:
raise ValueError(
f"{msg} To allow overriding this maximum, set "
"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
return int(max_model_len)


Expand Down
10 changes: 10 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
VLLM_NO_DEPRECATION_WARNING: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -331,6 +332,15 @@ def get_default_config_root():
# If set, vllm will skip the deprecation warnings.
"VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),

# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
# To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
"VLLM_ALLOW_LONG_MAX_MODEL_LEN":
lambda:
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
("1", "true")),
}

# end-env-vars-definition
Expand Down
Loading