From 9b9a10d6cb89f18e054daa66f25cb8f17c723b2c Mon Sep 17 00:00:00 2001
From: sasha0552
Date: Wed, 22 May 2024 05:32:35 +0000
Subject: [PATCH] [Frontend] Dynamic RoPE scaling (#4638)

---
 tests/test_config.py              | 56 ++++++++++++++++++++++++++++++-
 vllm/config.py                    |  7 +++-
 vllm/engine/arg_utils.py          | 18 +++++++---
 vllm/engine/llm_engine.py         | 10 +++---
 vllm/transformers_utils/config.py | 10 +++++-
 5 files changed, 89 insertions(+), 12 deletions(-)

diff --git a/tests/test_config.py b/tests/test_config.py
index 19db10630bbae..6bc51a53dc07c 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -36,4 +36,58 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() is None
 
     mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
\ No newline at end of file
+    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
+
+
+def test_rope_scaling():
+    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
+    assert llama_model_config.max_model_len == 8192
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert llama_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == LONGCHAT_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 4096
diff --git a/vllm/config.py b/vllm/config.py
index 44ed5635f9a35..3256c11967914 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -45,6 +45,9 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
+        rope_scaling: Dictionary containing the scaling configuration for the
+            RoPE embeddings. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -84,6 +87,7 @@ def __init__(
         seed: int,
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
+        rope_scaling: Optional[dict] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -102,6 +106,7 @@ def __init__(
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
+        self.rope_scaling = rope_scaling
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.quantization_param_path = quantization_param_path
@@ -116,7 +121,7 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision)
+                                    code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1ba424c4eeb14..0a9ec7472fbca 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1,5 +1,6 @@
 import argparse
 import dataclasses
+import json
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -49,6 +50,7 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
+    rope_scaling: Optional[dict] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -330,6 +332,11 @@ def add_cli_args(
                             'None, we assume the model weights are not '
                             'quantized and use `dtype` to determine the data '
                             'type of the weights.')
+        parser.add_argument('--rope-scaling',
+                            default=None,
+                            type=json.loads,
+                            help='RoPE scaling configuration in JSON format. '
+                            'For example, {"type":"dynamic","factor":2.0}')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig:
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
             self.trust_remote_code, self.dtype, self.seed, self.revision,
-            self.code_revision, self.tokenizer_revision, self.max_model_len,
-            self.quantization, self.quantization_param_path,
-            self.enforce_eager, self.max_context_len_to_capture,
-            self.max_seq_len_to_capture, self.max_logprobs,
-            self.skip_tokenizer_init, self.served_model_name)
+            self.code_revision, self.rope_scaling, self.tokenizer_revision,
+            self.max_model_len, self.quantization,
+            self.quantization_param_path, self.enforce_eager,
+            self.max_context_len_to_capture, self.max_seq_len_to_capture,
+            self.max_logprobs, self.skip_tokenizer_init,
+            self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f6a5284093c1c..60e23d4df15bb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -104,10 +104,11 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
-            "max_seq_len=%d, download_dir=%r, load_format=%s, "
-            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
-            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
+            "rope_scaling=%r, tokenizer_revision=%s, "
+            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
+            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
+            "disable_custom_all_reduce=%s, quantization=%s, "
+            "enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, seed=%d, served_model_name=%s)",
             vllm.__version__,
@@ -117,6 +118,7 @@ def __init__(
             model_config.skip_tokenizer_init,
             model_config.tokenizer_mode,
             model_config.revision,
+            model_config.rope_scaling,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 1756c91a612f0..f36d84dbdf7f9 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -2,9 +2,12 @@
 
 from transformers import AutoConfig, PretrainedConfig
 
+from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
 
+logger = init_logger(__name__)
+
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+               code_revision: Optional[str] = None,
+               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
@@ -41,6 +45,10 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
+    if rope_scaling is not None:
+        logger.info("Updating rope_scaling from %r to %r",
+                    getattr(config, "rope_scaling", None), rope_scaling)
+        config.update({"rope_scaling": rope_scaling})
     return config
 
 
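
Usage note (reviewer sketch, not part of the patch): the snippet below shows one way the new option could be exercised once this commit is applied. It relies only on what the diff itself adds (the --rope-scaling flag parsed with json.loads, and the rope_scaling dict threaded from EngineArgs through ModelConfig into get_config); the model name and the 8192 -> 16384 context growth are taken from test_rope_scaling above, and the server invocation assumes vLLM's OpenAI-compatible api_server entrypoint.

# Sketch, assuming this patch is applied to vLLM at this revision.
# CLI form -- the new flag takes a JSON object:
#   python -m vllm.entrypoints.openai.api_server \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       --rope-scaling '{"type":"dynamic","factor":2.0}'

# Python form -- pass the same dict through EngineArgs:
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    rope_scaling={"type": "dynamic", "factor": 2.0},  # dynamic NTK scaling, factor 2
)
engine_config = engine_args.create_engine_config()

# The overridden rope_scaling doubles the derived context window,
# mirroring test_rope_scaling above (8192 -> 16384).
assert engine_config.model_config.max_model_len == 16384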