diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 594ae442ef328..00c82fb77186c 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -1,4 +1,5 @@
 """Utilities for selecting and loading neuron models."""
+import copy
 import importlib
 import os
 from typing import Dict, List, Optional, Tuple
@@ -13,6 +14,8 @@
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput)
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -37,15 +40,18 @@
 
 class NeuronCasualLM(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 on_device_sampling_disabled: bool = False) -> None:
         super().__init__()
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
-        self.sampler = Sampler()
+
+        self.on_device_sampling_disabled = on_device_sampling_disabled
+        if self.on_device_sampling_disabled:
+            # Use default sampler
+            self.sampler = Sampler()
 
         # Lazy initialized
         self.model: nn.Module
@@ -71,8 +77,29 @@ def sample(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+
+        if self.on_device_sampling_disabled:
+            next_tokens = self.sampler(logits, sampling_metadata)
+            return next_tokens
+
+        # On-device sampling outputs the token ids directly.
+        sampled_token_ids = logits.flatten()
+        next_tokens = []
+        sample_idx = 0
+        for seq_group in sampling_metadata.seq_groups:
+            samples = []
+            for seq_id in seq_group.seq_ids:
+                token_id = sampled_token_ids[sample_idx].item()
+                samples.append(
+                    SequenceOutput(parent_seq_id=seq_id,
+                                   output_token=token_id,
+                                   logprobs={token_id: Logprob(token_id)}))
+                sample_idx += 1
+            next_tokens.append(
+                CompletionSequenceGroupOutput(samples=samples,
+                                              prompt_logprobs=None))
+
+        return SamplerOutput(outputs=next_tokens)
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
         quant=neuron_quantization_config_builder(model_config.quantization)
         if model_config.quantization else None,
         continuous_batching=continuous_batching_config,
-        weight_tiling=bool(model_config.quantization))
+        weight_tiling=bool(model_config.quantization),
+        on_device_generation=_get_neuron_on_device_generation_config(
+            model_config))
     return default_neuron_args
 
 
+def _get_neuron_on_device_generation_config(model_config: ModelConfig):
+    if not _is_neuron_on_device_sampling_disabled(model_config):
+        return copy.deepcopy(model_config.neuron_sampling_params)
+    return None
+
+
+def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
+    return not getattr(model_config, "neuron_sampling_params", None)
+
+
 def _get_neuron_config_after_override(default_neuron_config,
                                       overridden_neuron_config):
     from transformers_neuronx.config import NeuronConfig
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
 
     # Create a model instance.
-    model = NeuronCasualLM(model_config.hf_config)
+    model = NeuronCasualLM(
+        model_config.hf_config,
+        _is_neuron_on_device_sampling_disabled(model_config))
 
     default_neuron_config_args = _get_default_neuron_config(
         model_config, parallel_config, scheduler_config)
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 0cf7445d4388d..44d4845a838ef 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -1,9 +1,11 @@
+import os
 from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -50,6 +52,9 @@ def from_broadcasted_tensor_dict(
 
 class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
+    # NEURON has an upper limit on the top_k
+    _MAX_NEURON_SAMPLING_TOP_K = 256
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -76,6 +81,34 @@ def __init__(
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
+        # When NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
+        # on-device sampling is turned off.
+        self._on_device_sampling_disabled = int(
+            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
+
+        # NEURON needs to update sampling parameters when request IDs change
+        # across batches. This variable stores the previous batch's request IDs
+        # to determine if an update is needed.
+        self._previous_batch_request_ids: List[str] = []
+
+        if not self._on_device_sampling_disabled:
+            logger.warning(
+                "On-device sampling is turned on in Neuron by default; only "
+                "top_k, top_p, and temperature are currently supported "
+                "sampling parameters. To turn off on-device sampling, please "
+                "set the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
+            )
+            self.model_config.neuron_sampling_params = GenerationConfig(
+                max_length=self.scheduler_config.max_model_len,
+                do_sample=True,
+                per_batch_line=True,
+                top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
+                    * self.scheduler_config.max_num_seqs,
+                top_p=[1.0] * self.scheduler_config.max_num_seqs,
+                temperature=[1.0] * self.scheduler_config.max_num_seqs,
+                dynamic=True,
+                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
+
     def load_model(self) -> None:
         if find_spec("transformers_neuronx") is not None:
             self.model = get_neuron_model(
@@ -215,7 +248,7 @@ def prepare_model_input(
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -227,12 +260,49 @@ def prepare_model_input(
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
+        if not self._on_device_sampling_disabled:
+            # If the request IDs have changed since the previous iteration,
+            # update the on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(sampling_metadata)
+                self._previous_batch_request_ids = current_batch_request_ids
+
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
                                    sampling_metadata=sampling_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs)
 
+    def _update_neuron_sampling_params(self,
+                                       sampling_metadata: SamplingMetadata):
+        # Update Neuron sampling parameters (GenerationConfig in Neuron)
+        current_sampling_params = self.model_config.neuron_sampling_params
+        assert current_sampling_params is not None, (
+            f"Failed to update sampling_params, "
+            f"current sampling params is {current_sampling_params}")
+
+        top_k = current_sampling_params.top_k
+        top_p = current_sampling_params.top_p
+        temperature = current_sampling_params.temperature
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
+            top_k[index] = self._convert_to_neuron_top_k(
+                sequence_group_to_sample.sampling_params.top_k)
+            top_p[index] = sequence_group_to_sample.sampling_params.top_p
+            temperature[index] = \
+                sequence_group_to_sample.sampling_params.temperature
+
+        self.model.model.update_generation_config(current_sampling_params)
+
+    def _convert_to_neuron_top_k(self, top_k: int) -> int:
+        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+            return self._MAX_NEURON_SAMPLING_TOP_K
+        return top_k
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -253,9 +323,13 @@ def execute_model(
                                          device=self.device),
         )
 
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
-                                           model_input.sampling_metadata)
+        # Compute the logits only if on-device sampling is turned off, as
+        # on-device sampling outputs the token ids directly.
+        if self._on_device_sampling_disabled:
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+        else:
+            logits = hidden_states
 
         # Sample the next token.
         output = self.model.sample(
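
For reference, a minimal standalone sketch of the top_k clamping rule applied by `_convert_to_neuron_top_k` in the diff above. The function and constant names below are illustrative stand-ins; the 256 cap mirrors `_MAX_NEURON_SAMPLING_TOP_K`, and in vLLM `top_k == -1` means "consider the full vocabulary".

```python
# Illustrative sketch only; mirrors NeuronModelRunner._convert_to_neuron_top_k.
_MAX_NEURON_SAMPLING_TOP_K = 256  # Neuron's upper limit on top_k


def convert_to_neuron_top_k(top_k: int) -> int:
    # vLLM uses top_k == -1 for "no top-k restriction", which Neuron cannot
    # express, so that case (and any value above the cap) is clamped.
    if top_k < 0 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
        return _MAX_NEURON_SAMPLING_TOP_K
    return top_k


assert convert_to_neuron_top_k(-1) == 256    # "all tokens" -> hardware cap
assert convert_to_neuron_top_k(50) == 50     # already within the limit
assert convert_to_neuron_top_k(1024) == 256  # clamped to the cap
```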