diff --git a/engines/python/setup/djl_python/properties_manager/properties.py b/engines/python/setup/djl_python/properties_manager/properties.py index 2b17c84db..85ed64f79 100644 --- a/engines/python/setup/djl_python/properties_manager/properties.py +++ b/engines/python/setup/djl_python/properties_manager/properties.py @@ -61,9 +61,9 @@ class Properties(BaseModel): input_formatter: Optional[Callable] = None waiting_steps: Optional[int] = None mpi_mode: bool = False - tgi_compat: Optional[bool] = False - bedrock_compat: Optional[bool] = False - enable_lora: Optional[bool] = False + tgi_compat: bool = False + bedrock_compat: bool = False + enable_lora: bool = False # Spec_dec draft_model_id: Optional[str] = None diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py index 6e097e3e2..f93b0685b 100644 --- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py +++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py @@ -11,66 +11,72 @@ # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. import ast -from enum import Enum -from typing import Optional, Any, Mapping, Tuple, Dict - -from pydantic import field_validator, model_validator +import logging +from typing import Optional, Any, Dict, Tuple +from pydantic import field_validator, model_validator, ConfigDict, Field +from vllm import EngineArgs +from vllm.utils import FlexibleArgumentParser +from vllm.engine.arg_utils import StoreBoolean from djl_python.properties_manager.properties import Properties +DTYPE_MAPPER = { + "float32": "float32", + "fp32": "float32", + "float16": "float16", + "fp16": "float16", + "bfloat16": "bfloat16", + "bf16": "bfloat16", + "auto": "auto" +} + + +def construct_vllm_args_list(vllm_engine_args: dict, + parser: FlexibleArgumentParser): + # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258 + args_list = [] + store_boolean_arguments = { + action.dest + for action in parser._actions if isinstance(action, StoreBoolean) + } + for engine_arg, engine_arg_value in vllm_engine_args.items(): + if str(engine_arg_value).lower() in { + 'true', 'false' + } and engine_arg not in store_boolean_arguments: + if str(engine_arg_value).lower() == 'true': + args_list.append(f"--{engine_arg}") + else: + args_list.append(f"--{engine_arg}={engine_arg_value}") + return args_list + class VllmRbProperties(Properties): engine: Optional[str] = None - dtype: Optional[str] = "auto" - load_format: Optional[str] = "auto" - quantize: Optional[str] = None + # The following configs have different names in DJL compared to vLLM, we only accept DJL name currently tensor_parallel_degree: int = 1 pipeline_parallel_degree: int = 1 - max_rolling_batch_prefill_tokens: Optional[int] = None - # Adjustable prefix model length for certain 32k or longer model - max_model_len: Optional[int] = None - enforce_eager: Optional[bool] = False - # TODO: this default may change with different vLLM versions - # TODO: try to get good default from vLLM to prevent revisiting - # TODO: last time check: vllm 0.3.1 - gpu_memory_utilization: Optional[float] = 0.9 - enable_lora: Optional[bool] = False - max_loras: Optional[int] = 4 - max_lora_rank: Optional[int] = 16 - fully_sharded_loras: bool = False - lora_extra_vocab_size: int = 256 + # The following configs have different names in DJL compared to vLLM, either is accepted + quantize: Optional[str] = Field(alias="quantization", + default=EngineArgs.quantization) + max_rolling_batch_prefill_tokens: Optional[int] = Field( + alias="max_num_batched_tokens", + default=EngineArgs.max_num_batched_tokens) + cpu_offload_gb_per_gpu: float = Field(alias="cpu_offload_gb", + default=EngineArgs.cpu_offload_gb) + # The following configs have different defaults, or additional processing in DJL compared to vLLM + dtype: str = "auto" + max_loras: int = 4 + # The following configs have broken processing in vllm via the FlexibleArgumentParser long_lora_scaling_factors: Optional[Tuple[float, ...]] = None - lora_dtype: Optional[str] = 'auto' - max_cpu_loras: Optional[int] = None + use_v2_block_manager: bool = True # Neuron vLLM properties - device: Optional[str] = None + device: str = 'auto' preloaded_model: Optional[Any] = None generation_config: Optional[Any] = None - max_logprobs: Optional[int] = 20 - enable_chunked_prefill: Optional[bool] = None - cpu_offload_gb_per_gpu: Optional[int] = 0 - enable_prefix_caching: Optional[bool] = False - disable_sliding_window: Optional[bool] = False - limit_mm_per_prompt: Optional[Mapping[str, int]] = None - use_v2_block_manager: bool = False - tokenizer_mode: str = 'auto' - - # Speculative decoding configuration. - speculative_model: Optional[str] = None - speculative_model_quantization: Optional[str] = None - speculative_draft_tensor_parallel_size: Optional[int] = None - num_speculative_tokens: Optional[int] = None - speculative_max_model_len: Optional[int] = None - speculative_disable_by_batch_size: Optional[int] = None - ngram_prompt_lookup_max: Optional[int] = None - ngram_prompt_lookup_min: Optional[int] = None - spec_decoding_acceptance_method: str = 'rejection_sampler' - typical_acceptance_sampler_posterior_threshold: Optional[float] = None - typical_acceptance_sampler_posterior_alpha: Optional[float] = None - qlora_adapter_name_or_path: Optional[str] = None - disable_logprobs_during_spec_decoding: Optional[bool] = None + # This allows generic vllm engine args to be passed in and set with vllm + model_config = ConfigDict(extra='allow', populate_by_name=True) @field_validator('engine') def validate_engine(cls, engine): @@ -79,6 +85,24 @@ def validate_engine(cls, engine): f"Need python engine to start vLLM RollingBatcher") return engine + @field_validator('dtype') + def validate_dtype(cls, val): + if val not in DTYPE_MAPPER: + raise ValueError( + f"Invalid dtype={val} provided. Must be one of {DTYPE_MAPPER.keys()}" + ) + return DTYPE_MAPPER[val] + + @model_validator(mode='after') + def validate_pipeline_parallel(self): + if self.pipeline_parallel_degree != 1: + raise ValueError( + "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation" + ) + return self + + # TODO: processing of this field is broken in vllm via from_cli_args + # we should upstream a fix for this to vllm @field_validator('long_lora_scaling_factors', mode='before') def validate_long_lora_scaling_factors(cls, val): if isinstance(val, str): @@ -96,39 +120,75 @@ def validate_long_lora_scaling_factors(cls, val): ) return val - @field_validator('limit_mm_per_prompt', mode="before") - def validate_limit_mm_per_prompt(cls, val) -> Mapping[str, int]: - out_dict: Dict[str, int] = {} - for item in val.split(","): - kv_parts = [part.lower().strip() for part in item.split("=")] - if len(kv_parts) != 2: - raise ValueError("Each item should be in the form key=value") - key, value = kv_parts - - try: - parsed_value = int(value) - except ValueError as e: - raise ValueError( - f"Failed to parse value of item {key}={value}") from e + def handle_lmi_vllm_config_conflicts(self, additional_vllm_engine_args): - if key in out_dict and out_dict[key] != parsed_value: - raise ValueError( - f"Conflicting values specified for key: {key}") - out_dict[key] = parsed_value - return out_dict + def validate_potential_lmi_vllm_config_conflict( + lmi_config_name, vllm_config_name): + lmi_config_val = self.__getattribute__(lmi_config_name) + vllm_config_val = additional_vllm_engine_args.get(vllm_config_name) + if vllm_config_val is not None and lmi_config_val is not None: + if vllm_config_val != lmi_config_val: + raise ValueError( + f"Both the DJL {lmi_config_val}={lmi_config_val} and vLLM {vllm_config_name}={vllm_config_val} configs have been set with conflicting values." + f"We currently only accept the DJL config {lmi_config_name}, please remove the vllm {vllm_config_name} configuration." + ) - @model_validator(mode='after') - def validate_speculative_model(self): - if self.speculative_model is not None and not self.use_v2_block_manager: - raise ValueError( - "Speculative decoding requires usage of the V2 block manager. Enable it with option.use_v2_block_manager=true." - ) - return self + validate_potential_lmi_vllm_config_conflict("tensor_parallel_degree", + "tensor_parallel_size") + validate_potential_lmi_vllm_config_conflict("pipeline_parallel_degree", + "pipeline_parallel_size") + validate_potential_lmi_vllm_config_conflict("max_rolling_batch_size", + "max_num_seqs") - @model_validator(mode='after') - def validate_pipeline_parallel(self): - if self.pipeline_parallel_degree != 1: - raise ValueError( - "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation" - ) - return self + def generate_vllm_engine_arg_dict(self, + passthrough_vllm_engine_args) -> dict: + vllm_engine_args = { + 'model': self.model_id_or_path, + 'tensor_parallel_size': self.tensor_parallel_degree, + 'pipeline_parallel_size': self.pipeline_parallel_degree, + 'max_num_seqs': self.max_rolling_batch_size, + 'dtype': DTYPE_MAPPER[self.dtype], + 'revision': self.revision, + 'max_loras': self.max_loras, + 'enable_lora': self.enable_lora, + 'trust_remote_code': self.trust_remote_code, + 'cpu_offload_gb': self.cpu_offload_gb_per_gpu, + 'use_v2_block_manager': self.use_v2_block_manager, + 'quantization': self.quantize, + 'device': self.device, + } + if self.max_rolling_batch_prefill_tokens is not None: + vllm_engine_args[ + 'max_num_batched_tokens'] = self.max_rolling_batch_prefill_tokens + if self.device == 'neuron': + vllm_engine_args['block_size'] = passthrough_vllm_engine_args.get( + "max_model_len") + vllm_engine_args.update(passthrough_vllm_engine_args) + return vllm_engine_args + + def get_engine_args(self) -> EngineArgs: + additional_vllm_engine_args = self.get_additional_vllm_engine_args() + self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args) + vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict( + additional_vllm_engine_args) + logging.debug( + f"Construction vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}" + ) + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser) + args = parser.parse_args(args=args_list) + engine_args = EngineArgs.from_cli_args(args) + # we have to do this separately because vllm converts it into a string + engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors + # These neuron configs are not implemented in the vllm arg parser + if self.device == 'neuron': + setattr(engine_args, 'preloaded_model', self.preloaded_model) + setattr(engine_args, 'generation_config', self.generation_config) + return engine_args + + def get_additional_vllm_engine_args(self) -> Dict[str, Any]: + return { + k: v + for k, v in self.__pydantic_extra__.items() + if k in EngineArgs.__annotations__ + } diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py index bad3cc8eb..b68fb9f91 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py @@ -235,75 +235,6 @@ def get_lora_request(lora_name: str, lora_requests: dict) -> dict: return lora_requests[lora_name] -def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs: - if config.device == "neuron": - return EngineArgs(model=config.model_id_or_path, - preloaded_model=config.preloaded_model, - tensor_parallel_size=config.tensor_parallel_degree, - dtype=DTYPE_MAPPER[config.dtype], - seed=0, - max_model_len=config.max_model_len, - max_num_seqs=config.max_rolling_batch_size, - block_size=config.max_model_len, - trust_remote_code=config.trust_remote_code, - revision=config.revision, - device=config.device, - generation_config=config.generation_config) - else: - return EngineArgs( - model=config.model_id_or_path, - tensor_parallel_size=config.tensor_parallel_degree, - pipeline_parallel_size=config.pipeline_parallel_degree, - dtype=DTYPE_MAPPER[config.dtype], - seed=0, - max_model_len=config.max_model_len, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - max_num_batched_tokens=config.max_rolling_batch_prefill_tokens, - trust_remote_code=config.trust_remote_code, - load_format=config.load_format, - quantization=config.quantize, - enable_lora=config.enable_lora, - max_loras=config.max_loras, - max_lora_rank=config.max_lora_rank, - fully_sharded_loras=config.fully_sharded_loras, - lora_extra_vocab_size=config.lora_extra_vocab_size, - long_lora_scaling_factors=config.long_lora_scaling_factors, - lora_dtype=config.lora_dtype, - max_cpu_loras=config.max_cpu_loras, - revision=config.revision, - max_logprobs=config.max_logprobs, - enable_chunked_prefill=config.enable_chunked_prefill, - cpu_offload_gb=config.cpu_offload_gb_per_gpu, - enable_prefix_caching=config.enable_prefix_caching, - disable_sliding_window=config.disable_sliding_window, - max_num_seqs=config.max_rolling_batch_size, - use_v2_block_manager=config.use_v2_block_manager, - speculative_model=config.speculative_model, - speculative_model_quantization=config. - speculative_model_quantization, - speculative_draft_tensor_parallel_size=config. - speculative_draft_tensor_parallel_size, - num_speculative_tokens=config.num_speculative_tokens, - speculative_max_model_len=config.speculative_max_model_len, - speculative_disable_by_batch_size=config. - speculative_disable_by_batch_size, - ngram_prompt_lookup_max=config.ngram_prompt_lookup_max, - ngram_prompt_lookup_min=config.ngram_prompt_lookup_min, - spec_decoding_acceptance_method=config. - spec_decoding_acceptance_method, - typical_acceptance_sampler_posterior_threshold=config. - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=config. - typical_acceptance_sampler_posterior_alpha, - qlora_adapter_name_or_path=config.qlora_adapter_name_or_path, - disable_logprobs_during_spec_decoding=config. - disable_logprobs_during_spec_decoding, - limit_mm_per_prompt=config.limit_mm_per_prompt, - tokenizer_mode=config.tokenizer_mode, - ) - - def get_multi_modal_data(request: Request) -> Optional[dict]: parameters = request.parameters images = parameters.pop("images", None) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 66abbf811..d49b753bd 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -19,7 +19,7 @@ from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params from djl_python.rolling_batch.rolling_batch_vllm_utils import ( update_request_cache_with_output, create_lora_request, get_lora_request, - get_engine_args_from_config, get_prompt_inputs) + get_prompt_inputs) from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties from typing import List, Optional @@ -47,12 +47,12 @@ def __init__(self, model_id_or_path: str, properties: dict, """ self.vllm_configs = VllmRbProperties(**properties) super().__init__(self.vllm_configs) - args = get_engine_args_from_config(self.vllm_configs) + args = self.vllm_configs.get_engine_args() self.engine = LLMEngine.from_engine_args(args) self.request_cache = OrderedDict() self.lora_id_counter = AtomicCounter(0) self.lora_requests = {} - self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral' + self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral' def get_tokenizer(self): return self.engine.tokenizer.tokenizer diff --git a/engines/python/setup/djl_python/tests/test_properties_manager.py b/engines/python/setup/djl_python/tests/test_properties_manager.py index 3421a2ee2..06e6fb88e 100644 --- a/engines/python/setup/djl_python/tests/test_properties_manager.py +++ b/engines/python/setup/djl_python/tests/test_properties_manager.py @@ -3,6 +3,8 @@ import unittest from unittest import mock +from vllm import EngineArgs + from djl_python.properties_manager.properties import Properties from djl_python.properties_manager.tnx_properties import ( TransformerNeuronXProperties, TnXGenerationStrategy, TnXModelSchema, @@ -421,85 +423,226 @@ def test_hf_error_case(self, params): HuggingFaceProperties(**params) def test_vllm_properties(self): - # test with valid vllm properties - def test_vllm_valid(properties): - vllm_configs = VllmRbProperties(**properties) - self.assertEqual(vllm_configs.model_id_or_path, - properties['model_id']) - self.assertEqual(vllm_configs.engine, properties['engine']) + def validate_vllm_config_and_engine_args_match( + vllm_config_value, + engine_arg_value, + expected_value, + ): + self.assertEqual(vllm_config_value, expected_value) + self.assertEqual(engine_arg_value, expected_value) + + def test_vllm_default_properties(): + required_properties = { + "engine": "Python", + "model_id": "some_model", + } + vllm_configs = VllmRbProperties(**required_properties) + engine_args = vllm_configs.get_engine_args() + validate_vllm_config_and_engine_args_match( + vllm_configs.model_id_or_path, engine_args.model, "some_model") + validate_vllm_config_and_engine_args_match( + vllm_configs.tensor_parallel_degree, + engine_args.tensor_parallel_size, 1) + validate_vllm_config_and_engine_args_match( + vllm_configs.pipeline_parallel_degree, + engine_args.pipeline_parallel_size, 1) + validate_vllm_config_and_engine_args_match( + vllm_configs.quantize, engine_args.quantization, None) + validate_vllm_config_and_engine_args_match( + vllm_configs.max_rolling_batch_size, engine_args.max_num_seqs, + 32) + validate_vllm_config_and_engine_args_match(vllm_configs.dtype, + engine_args.dtype, + 'auto') + validate_vllm_config_and_engine_args_match(vllm_configs.max_loras, + engine_args.max_loras, + 4) + validate_vllm_config_and_engine_args_match( + vllm_configs.cpu_offload_gb_per_gpu, + engine_args.cpu_offload_gb, EngineArgs.cpu_offload_gb) self.assertEqual( + len(vllm_configs.get_additional_vllm_engine_args()), 0) + + def test_invalid_pipeline_parallel(): + properties = { + "engine": "Python", + "model_id": "some_model", + "tensor_parallel_degree": "4", + "pipeline_parallel_degree": "2", + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + def test_invalid_engine(): + properties = { + "engine": "bad_engine", + "model_id": "some_model", + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + def test_aliases(): + properties = { + "engine": "Python", + "model_id": "some_model", + "quantization": "awq", + "max_num_batched_tokens": "546", + "cpu_offload_gb": "7" + } + vllm_configs = VllmRbProperties(**properties) + engine_args = vllm_configs.get_engine_args() + validate_vllm_config_and_engine_args_match( + vllm_configs.quantize, engine_args.quantization, "awq") + validate_vllm_config_and_engine_args_match( vllm_configs.max_rolling_batch_prefill_tokens, - int(properties['max_rolling_batch_prefill_tokens'])) - self.assertEqual(vllm_configs.dtype, properties['dtype']) - self.assertEqual(vllm_configs.load_format, - properties['load_format']) - self.assertEqual(vllm_configs.quantize, properties['quantize']) - self.assertEqual(vllm_configs.tensor_parallel_degree, - int(properties['tensor_parallel_degree'])) - self.assertEqual(vllm_configs.max_model_len, - int(properties['max_model_len'])) - self.assertEqual(vllm_configs.enforce_eager, - bool(properties['enforce_eager'])) - self.assertEqual(vllm_configs.enable_lora, - bool(properties['enable_lora'])) - self.assertEqual(vllm_configs.gpu_memory_utilization, - float(properties['gpu_memory_utilization'])) + engine_args.max_num_batched_tokens, 546) + validate_vllm_config_and_engine_args_match( + vllm_configs.cpu_offload_gb_per_gpu, + engine_args.cpu_offload_gb, 7) - def test_enforce_eager(properties): - properties.pop('enforce_eager') - properties.pop('quantize') - self.assertTrue("enforce_eager" not in properties) - vllm_props = VllmRbProperties(**properties) - self.assertTrue(vllm_props.enforce_eager is False) + def test_vllm_passthrough_properties(): + properties = { + "engine": "Python", + "model_id": "some_model", + "tensor_parallel_degree": "4", + "pipeline_parallel_degree": "1", + "max_rolling_batch_size": "111", + "quantize": "awq", + "max_rolling_batch_prefill_tokens": "400", + "cpu_offload_gb_per_gpu": "8", + "dtype": "bf16", + "max_loras": "7", + "long_lora_scaling_factors": "1.1, 2.0", + "trust_remote_code": "true", + "max_model_len": "1024", + "enforce_eager": "true", + "enable_chunked_prefill": "False", + "gpu_memory_utilization": "0.4", + } + vllm_configs = VllmRbProperties(**properties) + engine_args = vllm_configs.get_engine_args() + self.assertTrue( + len(vllm_configs.get_additional_vllm_engine_args()) > 0) + validate_vllm_config_and_engine_args_match( + vllm_configs.model_id_or_path, engine_args.model, "some_model") + validate_vllm_config_and_engine_args_match( + vllm_configs.tensor_parallel_degree, + engine_args.tensor_parallel_size, 4) + validate_vllm_config_and_engine_args_match( + vllm_configs.pipeline_parallel_degree, + engine_args.pipeline_parallel_size, 1) + validate_vllm_config_and_engine_args_match( + vllm_configs.max_rolling_batch_size, engine_args.max_num_seqs, + 111) + validate_vllm_config_and_engine_args_match( + vllm_configs.quantize, engine_args.quantization, "awq") + validate_vllm_config_and_engine_args_match( + vllm_configs.max_rolling_batch_prefill_tokens, + engine_args.max_num_batched_tokens, 400) + validate_vllm_config_and_engine_args_match( + vllm_configs.cpu_offload_gb_per_gpu, + engine_args.cpu_offload_gb, 8.0) + validate_vllm_config_and_engine_args_match(vllm_configs.dtype, + engine_args.dtype, + "bfloat16") + validate_vllm_config_and_engine_args_match(vllm_configs.max_loras, + engine_args.max_loras, + 7) + validate_vllm_config_and_engine_args_match( + vllm_configs.long_lora_scaling_factors, + engine_args.long_lora_scaling_factors, (1.1, 2.0)) + validate_vllm_config_and_engine_args_match( + vllm_configs.trust_remote_code, engine_args.trust_remote_code, + True) + self.assertEqual(engine_args.max_model_len, 1024) + self.assertEqual(engine_args.enforce_eager, True) + self.assertEqual(engine_args.enable_chunked_prefill, False) + self.assertEqual(engine_args.gpu_memory_utilization, 0.4) - def test_long_lora_scaling_factors(properties): - properties['long_lora_scaling_factors'] = "3.0" + def test_long_lora_scaling_factors(): + properties = { + "engine": "Python", + "model_id": "some_model", + "long_lora_scaling_factors": "3.0" + } vllm_props = VllmRbProperties(**properties) - self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, )) + engine_args = vllm_props.get_engine_args() + self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, )) properties['long_lora_scaling_factors'] = "3" vllm_props = VllmRbProperties(**properties) - self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, )) + engine_args = vllm_props.get_engine_args() + self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, )) properties['long_lora_scaling_factors'] = "3.0,4.0" vllm_props = VllmRbProperties(**properties) - self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, 4.0)) + engine_args = vllm_props.get_engine_args() + self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, 4.0)) properties['long_lora_scaling_factors'] = "3.0, 4.0 " vllm_props = VllmRbProperties(**properties) - self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, 4.0)) + engine_args = vllm_props.get_engine_args() + self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, 4.0)) properties['long_lora_scaling_factors'] = "(3.0,)" vllm_props = VllmRbProperties(**properties) - self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, )) + engine_args = vllm_props.get_engine_args() + self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, )) properties['long_lora_scaling_factors'] = "(3.0,4.0)" vllm_props = VllmRbProperties(**properties) - self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, 4.0)) + engine_args = vllm_props.get_engine_args() + self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, 4.0)) - def test_invalid_long_lora_scaling_factors(properties): - properties['long_lora_scaling_factors'] = "a,b" + def test_invalid_long_lora_scaling_factors(): + properties = { + "engine": "Python", + "model_id": "some_model", + "long_lora_scaling_factors": "a,b" + } with self.assertRaises(ValueError): - VllmRbProperties(**properties) + _ = VllmRbProperties(**properties) - properties = { - 'model_id': 'sample_model_id', - 'engine': 'Python', - 'max_rolling_batch_prefill_tokens': '12500', - 'max_model_len': '12800', - 'tensor_parallel_degree': '2', - 'dtype': 'fp16', - 'quantize': 'awq', - 'enforce_eager': "True", - 'enable_lora': "true", - "gpu_memory_utilization": "0.85", - 'load_format': 'pt' - } - test_vllm_valid(properties.copy()) - test_enforce_eager(properties.copy()) - test_long_lora_scaling_factors(properties.copy()) - test_invalid_long_lora_scaling_factors(properties.copy()) + def test_conflicting_djl_vllm_conflicts(): + properties = { + "engine": "Python", + "model_id": "some_model", + "tensor_parallel_degree": 2, + "tensor_parallel_size": 1, + } + vllm_configs = VllmRbProperties(**properties) + with self.assertRaises(ValueError): + vllm_configs.get_engine_args() + + properties = { + "engine": "Python", + "model_id": "some_model", + "pipeline_parallel_degree": 1, + "pipeline_parallel_size": 0, + } + vllm_configs = VllmRbProperties(**properties) + with self.assertRaises(ValueError): + vllm_configs.get_engine_args() + + properties = { + "engine": "Python", + "model_id": "some_model", + "max_rolling_batch_size": 1, + "max_num_seqs": 2, + } + vllm_configs = VllmRbProperties(**properties) + with self.assertRaises(ValueError): + vllm_configs.get_engine_args() + + test_vllm_default_properties() + test_invalid_pipeline_parallel() + test_invalid_engine() + test_aliases() + test_vllm_passthrough_properties() + test_long_lora_scaling_factors() + test_invalid_long_lora_scaling_factors() + test_conflicting_djl_vllm_conflicts() def test_sd_inf2_properties(self): properties = { diff --git a/engines/python/setup/setup.py b/engines/python/setup/setup.py index 6883bbf89..0b625879b 100644 --- a/engines/python/setup/setup.py +++ b/engines/python/setup/setup.py @@ -58,7 +58,7 @@ def run(self): test_requirements = [ 'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops', 'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf', - 'pydantic>=2.0', "objgraph" + 'pydantic>=2.0', "objgraph", 'vllm==0.6.3.post1' ] setup(name='djl_python', diff --git a/serving/docs/lmi/user_guides/vllm_user_guide.md b/serving/docs/lmi/user_guides/vllm_user_guide.md index 5206eba0f..4952c5b83 100644 --- a/serving/docs/lmi/user_guides/vllm_user_guide.md +++ b/serving/docs/lmi/user_guides/vllm_user_guide.md @@ -8,11 +8,11 @@ vLLM expects the model artifacts to be in the [standard HuggingFace format](../d **Text Generation Models** -Here is the list of text generation models supported in [vllm 0.6.2](https://docs.vllm.ai/en/v0.6.2/models/supported_models.html#decoder-only-language-models). +Here is the list of text generation models supported in [vllm 0.6.3.post1](https://docs.vllm.ai/en/v0.6.3.post1/models/supported_models.html#decoder-only-language-models). **Multi Modal Models** -Here is the list of multi-modal models supported in [vllm 0.6.2](https://docs.vllm.ai/en/v0.6.2/models/supported_models.html#decoder-only-language-models). +Here is the list of multi-modal models supported in [vllm 0.6.3.post1](https://docs.vllm.ai/en/v0.6.3.post1/models/supported_models.html#decoder-only-language-models). ### Model Coverage in CI @@ -34,10 +34,14 @@ The following set of models are tested in our nightly tests ## Quantization Support -The quantization techniques supported in vLLM 0.6.2 are listed [here](https://docs.vllm.ai/en/v0.6.2/quantization/supported_hardware.html). +The quantization techniques supported in vLLM 0.6.3.post1 are listed [here](https://docs.vllm.ai/en/v0.6.3.post1/quantization/supported_hardware.html). -We highly recommend that regardless of which quantization technique you are using that you pre-quantize the model. -Runtime quantization adds additional overhead to the endpoint startup time, and depending on the quantization technique, this can be significant overhead. +We recommend that regardless of which quantization technique you are using that you pre-quantize the model. +Runtime quantization adds additional overhead to the endpoint startup time. +Depending on the quantization technique, this can be significant overhead. +If you are using a pre-quantized model, you should not set any quantization specific configurations. +vLLM will deduce the quantization from the model config, and apply optimizations at runtime. +If you explicitly set the quantization configuration for a pre-quantized model, it limits the optimizations that vLLM can apply. The following quantization techniques are supported for runtime quantization: @@ -47,7 +51,7 @@ The following quantization techniques are supported for runtime quantization: You can leverage these techniques by specifying `option.quantize=` in serving.properties, or `OPTION_QUANTIZE=` environment variable. Other quantization techniques supported by vLLM require ahead of time quantization to be served with LMI. -You can find details on how to leverage those quantization techniques from the vLLM docs [here](https://docs.vllm.ai/en/v0.6.2/quantization/supported_hardware.html). +You can find details on how to leverage those quantization techniques from the vLLM docs [here](https://docs.vllm.ai/en/v0.6.3.post1/quantization/supported_hardware.html). ## Quick Start Configurations @@ -82,7 +86,13 @@ You can follow [this example](../deployment_guide/deploying-your-endpoint.md#con vLLM has support for LoRA adapters using the [adapters API](../../adapters.md). In order to use the adapters, you must begin by enabling them by setting `option.enable_lora=true`. -Following that, you can configure the LoRA support through the additional settings `option.max_loras`, `option.max_lora_rank`, `option.lora_extra_vocab_size`, and `option.max_cpu_loras`. +Following that, you can configure the LoRA support through the additional settings: + + - `option.max_loras` + - `option.max_lora_rank` + - `option.lora_extra_vocab_size` + - `option.max_cpu_loras` + - If you run into OOM by enabling adapter support, reduce the `option.gpu_memory_utilization`. ### Advanced vLLM Configurations @@ -96,22 +106,35 @@ For `LMI` configurations, if we determine an issue with the configuration, we wi For `Pass Through` configurations it is possible that our investigation reveals an issue with the backend library. In that situation, there is nothing LMI can do until the issue is fixed in the backend library. -| Item | LMI Version | Configuration Type | Description | Example value | -|-----------------------------------------|-------------|--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------| -| option.quantize | \>= 0.26.0 | LMI | Quantize the model with the supported quantization methods. LMI uses this to set the right quantization configs in VLLM | `awq` Default: `None` | -| option.max_rolling_batch_prefill_tokens | \>= 0.26.0 | LMI | Limits the number of tokens for prefill(a.k.a prompt processing). This needs to be tuned based on GPU memory available and request lengths. Setting this value too high can limit the number of kv cache blocks or run into OOM. If you don't set this, `vllm` will default to max model length from Hugging Face config(also accounts for rope scaling if applicable). | Default: `None` | -| option.max_model_len | \>= 0.26.0 | Pass Through | the maximum length (input+output) vLLM should preserve memory for. If not specified, will use the default length the model is capable in config.json. Sometimes model's maximum length could go to 32k (Mistral 7B) and way beyond the supported KV token size. In that case to deploy on a small instance, we need to adjust this value within the range of KV Cache limit. | Default: `None` | -| option.load_format | \>= 0.26.0 | Pass Through | The checkpoint format of the model. Default is auto and means bin/safetensors will be used if found. | Default: `auto` | -| option.enforce_eager | \>= 0.27.0 | Pass Through | vLLM by default will run with CUDA graph optimization to reach to the best performance. However, in the situation of very less GPU memory, having CUDA graph enabled will cause OOM. So if you set this option to true, we will use PyTorch Eager mode and disable CUDA graph to save some GBs of memory. | Default: `False` | -| option.gpu_memory_utilization | \>= 0.27.0 | Pass Through | This config controls the amount of GPU memory allocated to KV cache. Setting higher value will allocate more memory for KV cache.Default is 0.9. It recommended to reduce this value if GPU OOM's are encountered. | Default: `0.9` | -| option.enable_lora | \>= 0.27.0 | Pass Through | This config enables support for LoRA adapters. | Default: `false` | -| option.max_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters that can be run at once. Allocates GPU memory for those number of adapters. | Default: `4` | -| option.max_lora_rank | \>= 0.27.0 | Pass Through | This config determines the maximum rank allowed for a LoRA adapter. Set this value to maximum rank of your adapters. Setting a larger value will enable more adapters at a greater memory usage cost. | Default: `16` | -| option.lora_extra_vocab_size | \>= 0.27.0 | Pass Through | This config determines the maximum additional vocabulary that can be added through a LoRA adapter. | Default: `256` | -| option.max_cpu_loras | \>= 0.27.0 | Pass Through | This config determines the maximum number of LoRA adapters to cache in memory. All others will be evicted to disk. | Default: `None` | -| option.enable_chunked_prefill | \>= 0.29.0 | Pass Through | This config enables chunked prefill support. With chunked prefill, longer prompts will be chunked and batched with decode requests to reduce inter token latency. This option is EXPERIMENTAL and tested for llama and falcon models only. This does not work with LoRA and speculative decoding yet. | Default: `None` | -| option.cpu_offload_gb_per_gpu | \>= 0.29.0 | Pass Through | This config allows offloading model weights into CPU to enable large model running with limited GPU memory. | Default: `0` | -| option.enable_prefix_caching | \>= 0.29.0 | Pass Through | This config allows the engine to cache the context memory and reuse to speed up inference. | Default: `False` | -| option.disable_sliding_window | \>= 0.30.0 | Pass Through | This config disables sliding window, capping to sliding window size inference. | Default: `False` | -| option.tokenizer_mode | \>= 0.30.0 | Pass Through | This config sets the tokenizer mode for vllm. When using mistral models with mistral tokenizers, you must set this to `mistral` explicitly. | Default: `auto` | +The following table lists the set of LMI configurations. +In some situations, the equivalent vLLM configuration can be used interchangeably. +Those situations will be called out specifically. + +| Item | LMI Version | vLLM alias | Example Value | Default Value | Description | +|-----------------------------------------|-------------|-------------------------------|---------------|--------------------------------------------|--------------------------------------------------------------------------------------------------------------| +| option.quantize | \>= 0.23.0 | option.quantization | awq | None | The quantization algorithm to use. See "Quantization Support" for more details | +| option.max_rolling_batch_prefill_tokens | \>= 0.24.0 | option.max_num_batched_tokens | 32768 | None | Maximum number of tokens that the engine can process in a single batch iteration (includes prefill + decode) | +| option.cpu_offload_gb_per_gpu | \>= 0.29.0 | option.cpu_offload_gb | 4 (GB) | 0 | The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. | + +In addition to the configurations specified in the table above, LMI supports all additional vLLM EngineArguments in Pass-Through mode. +Pass-Through configurations are not processed or validated by LMI. +You can find the set of EngineArguments supported by vLLM [here](https://docs.vllm.ai/en/v0.6.3.post1/models/engine_args.html#engine-args). + +You can specify these pass-through configurations in the serving.properties file by prefixing the configuration with `option.`, +or as environment variables by prefixing the configuration with `OPTION_`. + +We will consider two examples: a boolean configuration, and a string configuration. + +**Boolean Configuration** + +If you want to set the Engine Argument `enable_prefix_caching`, you can do: + +* `option.enable_prefix_caching=true` in serving.properties +* `OPTION_ENABLE_PREFIX_CACHING=true` as an environment variable + +**String configuration** + +If you want to set the Engine Argument `tokenizer_mode`, you can do: +* `option.tokenizer_mode=mistral` in serving.properties +* `OPTION_TOKENIZER_MODE=true` in serving.properties