Skip to content

Commit

Permalink
[Hardware][Neuron] Add on-device sampling support for Neuron (#8746)
Browse files Browse the repository at this point in the history
Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com>
  • Loading branch information
chongmni-aws and aws-aymahg authored Oct 4, 2024
1 parent 27302dd commit cc90419
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 13 deletions.
59 changes: 50 additions & 9 deletions vllm/model_executor/model_loader/neuron.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for selecting and loading neuron models."""
import copy
import importlib
import os
from typing import Dict, List, Optional, Tuple
Expand All @@ -13,6 +14,8 @@
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)

TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
Expand All @@ -37,15 +40,18 @@

class NeuronCasualLM(nn.Module):

def __init__(
self,
config: PretrainedConfig,
) -> None:
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()

self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()

# Lazy initialized
self.model: nn.Module
Expand All @@ -71,8 +77,29 @@ def sample(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

if self.on_device_sampling_disabled:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

# On-device sampling outputs the token ids directly.
sampled_token_ids = logits.flatten()
next_tokens = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
samples = []
for seq_id in seq_group.seq_ids:
token_id = sampled_token_ids[sample_idx].item()
samples.append(
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
next_tokens.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))

return SamplerOutput(outputs=next_tokens)

def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
Expand Down Expand Up @@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
quant=neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None,
continuous_batching=continuous_batching_config,
weight_tiling=bool(model_config.quantization))
weight_tiling=bool(model_config.quantization),
on_device_generation=_get_neuron_on_device_generation_config(
model_config))
return default_neuron_args


def _get_neuron_on_device_generation_config(model_config: ModelConfig):
if not _is_neuron_on_device_sampling_disabled(model_config):
return copy.deepcopy(model_config.neuron_sampling_params)
return None


def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
return not getattr(model_config, "neuron_sampling_params", None)


def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
from transformers_neuronx.config import NeuronConfig
Expand All @@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:

# Create a model instance.
model = NeuronCasualLM(model_config.hf_config)
model = NeuronCasualLM(
model_config.hf_config,
_is_neuron_on_device_sampling_disabled(model_config))

default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config)
Expand Down
82 changes: 78 additions & 4 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers_neuronx.config import GenerationConfig

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
Expand Down Expand Up @@ -50,6 +52,9 @@ def from_broadcasted_tensor_dict(

class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

# NEURON has an upper limit on the top_k
_MAX_NEURON_SAMPLING_TOP_K = 256

def __init__(
self,
model_config: ModelConfig,
Expand All @@ -76,6 +81,34 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # initialize after load_model.

# Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
# turn off on-device sampling.
self._on_device_sampling_disabled = int(
os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

# NEURON needs to update sampling parameters when request IDs change
# across batches. This variable stores the previous batch's request IDs
# to determine if an update is needed.
self._previous_batch_request_ids: List[str] = []

if not self._on_device_sampling_disabled:
logger.warning(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
)
self.model_config.neuron_sampling_params = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)

def load_model(self) -> None:
if find_spec("transformers_neuronx") is not None:
self.model = get_neuron_model(
Expand Down Expand Up @@ -215,7 +248,7 @@ def prepare_model_input(
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
seq_lens = None
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
Expand All @@ -227,12 +260,49 @@ def prepare_model_input(
self.pin_memory,
generators=self.get_generators(finished_requests_ids))

if not self._on_device_sampling_disabled:
# Once the request IDs are changed in current iteration, we will
# update the on-device sampling parameters.
current_batch_request_ids = [
seq_group_meta_data.request_id
for seq_group_meta_data in seq_group_metadata_list
]
if current_batch_request_ids != self._previous_batch_request_ids:
self._update_neuron_sampling_params(sampling_metadata)
self._previous_batch_request_ids = current_batch_request_ids

return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)

def _update_neuron_sampling_params(self,
sampling_metadata: SamplingMetadata):
# Update Neuron sampling parameters (GenerationConfig in Neuron)
current_sampling_params = self.model_config.neuron_sampling_params
assert current_sampling_params is not None, (
f"Failed to update sampling_params, "
f"current sampling params is {current_sampling_params}")

top_k = current_sampling_params.top_k
top_p = current_sampling_params.top_p
temperature = current_sampling_params.temperature
for index, sequence_group_to_sample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = self._convert_to_neuron_top_k(
sequence_group_to_sample.sampling_params.top_k)
top_p[index] = sequence_group_to_sample.sampling_params.top_p
temperature[index] = \
sequence_group_to_sample.sampling_params.temperature

self.model.model.update_generation_config(current_sampling_params)

def _convert_to_neuron_top_k(self, top_k: int) -> int:
if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
return self._MAX_NEURON_SAMPLING_TOP_K
return top_k

@torch.inference_mode()
def execute_model(
self,
Expand All @@ -253,9 +323,13 @@ def execute_model(
device=self.device),
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Compute the logits only if the on-device sampling is turned off as
# on-device sampling outputs the token ids.
if self._on_device_sampling_disabled:
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
else:
logits = hidden_states

# Sample the next token.
output = self.model.sample(
Expand Down

0 comments on commit cc90419

Please sign in to comment.