[Hardware][Neuron] Add on-device sampling support for Neuron #8746

Merged: 12 commits, Oct 4, 2024
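For context, a minimal offline-inference sketch of how the new behavior would be exercised (not part of this diff; the model name, sizes, and parallelism below are placeholders, and the snippet assumes a Neuron-capable vLLM install):

    import os

    from vllm import LLM, SamplingParams

    # Optional: set this before engine construction to fall back to the
    # host-side vLLM Sampler instead of sampling on the Neuron device.
    # os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "1"

    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
        device="neuron",
        max_num_seqs=8,
        max_model_len=128,
        tensor_parallel_size=2,
    )

    # With on-device sampling (the new default), only temperature, top_p, and
    # top_k are honored on the device; top_k is clamped to Neuron's limit of 256.
    params = SamplingParams(temperature=0.7, top_p=0.9, top_k=40, max_tokens=32)
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)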
59 changes: 50 additions & 9 deletions vllm/model_executor/model_loader/neuron.py
@@ -1,4 +1,5 @@
"""Utilities for selecting and loading neuron models."""
import copy
import importlib
import os
from typing import Dict, List, Optional, Tuple
@@ -13,6 +14,8 @@
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)

TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
@@ -37,15 +40,18 @@

class NeuronCasualLM(nn.Module):

- def __init__(
- self,
- config: PretrainedConfig,
- ) -> None:
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
- self.sampler = Sampler()

self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()

# Lazy initialized
self.model: nn.Module
@@ -71,8 +77,29 @@ def sample(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
- next_tokens = self.sampler(logits, sampling_metadata)
- return next_tokens

if self.on_device_sampling_disabled:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

# On-device sampling outputs the token ids directly.
sampled_token_ids = logits.flatten()
next_tokens = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
samples = []
for seq_id in seq_group.seq_ids:
token_id = sampled_token_ids[sample_idx].item()
samples.append(
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
next_tokens.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))

return SamplerOutput(outputs=next_tokens)
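A self-contained sketch (not part of the diff) of what this conversion does: with on-device sampling, the tensor handed to sample() already holds token ids, which are repackaged per sequence group. The sequence groups below are hypothetical stubs rather than real SamplingMetadata:

    import torch
    from types import SimpleNamespace

    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                               SequenceOutput)

    # Two sequence groups with one sequence each; the device sampled 11 and 42.
    sampled_token_ids = torch.tensor([[11], [42]]).flatten()
    seq_groups = [SimpleNamespace(seq_ids=[0]), SimpleNamespace(seq_ids=[1])]

    outputs, sample_idx = [], 0
    for seq_group in seq_groups:
        samples = []
        for seq_id in seq_group.seq_ids:
            token_id = sampled_token_ids[sample_idx].item()
            # Mirroring the code above, the token id doubles as a placeholder
            # logprob value; no real logprobs come back from the device.
            samples.append(
                SequenceOutput(parent_seq_id=seq_id,
                               output_token=token_id,
                               logprobs={token_id: Logprob(token_id)}))
            sample_idx += 1
        outputs.append(
            CompletionSequenceGroupOutput(samples=samples, prompt_logprobs=None))

    print(SamplerOutput(outputs=outputs))  # two groups, one sampled token each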

def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
quant=neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None,
continuous_batching=continuous_batching_config,
- weight_tiling=bool(model_config.quantization))
weight_tiling=bool(model_config.quantization),
on_device_generation=_get_neuron_on_device_generation_config(
model_config))
return default_neuron_args


def _get_neuron_on_device_generation_config(model_config: ModelConfig):
if not _is_neuron_on_device_sampling_disabled(model_config):
return copy.deepcopy(model_config.neuron_sampling_params)
return None


def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
return not getattr(model_config, "neuron_sampling_params", None)


def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
from transformers_neuronx.config import NeuronConfig
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:

# Create a model instance.
- model = NeuronCasualLM(model_config.hf_config)
model = NeuronCasualLM(
model_config.hf_config,
_is_neuron_on_device_sampling_disabled(model_config))

default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config)
82 changes: 78 additions & 4 deletions vllm/worker/neuron_model_runner.py
@@ -1,9 +1,11 @@
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers_neuronx.config import GenerationConfig

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@@ -50,6 +52,9 @@ def from_broadcasted_tensor_dict(

class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

# NEURON has an upper limit on the top_k
_MAX_NEURON_SAMPLING_TOP_K = 256

def __init__(
self,
model_config: ModelConfig,
@@ -76,6 +81,34 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # initialize after load_model.

# On-device sampling is turned off when NEURON_ON_DEVICE_SAMPLING_DISABLED
# is set to a non-zero value.
self._on_device_sampling_disabled = int(
os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

# NEURON needs to update sampling parameters when request IDs change
# across batches. This variable stores the previous batch's request IDs
# to determine if an update is needed.
self._previous_batch_request_ids: List[str] = []

if not self._on_device_sampling_disabled:
logger.warning(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
)
self.model_config.neuron_sampling_params = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
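A short illustration (not part of the diff) of how these per-sequence defaults interact with per-request parameters later, in _update_neuron_sampling_params; the batch size and values below are made up:

    # Defaults for a hypothetical max_num_seqs=2: every slot starts permissive.
    top_k = [256, 256]
    top_p = [1.0, 1.0]
    temperature = [1.0, 1.0]

    # A batch whose first sequence group requests top_k=40, top_p=0.9,
    # temperature=0.7 only rewrites slot 0; slot 1 keeps its defaults until
    # another request is scheduled into it.
    top_k[0], top_p[0], temperature[0] = 40, 0.9, 0.7
    assert (top_k, top_p, temperature) == ([40, 256], [0.9, 1.0], [0.7, 1.0])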

def load_model(self) -> None:
if find_spec("transformers_neuronx") is not None:
self.model = get_neuron_model(
@@ -215,7 +248,7 @@ def prepare_model_input(
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
- seq_lens = []
seq_lens = None
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
@@ -227,12 +260,49 @@
self.pin_memory,
generators=self.get_generators(finished_requests_ids))

if not self._on_device_sampling_disabled:
# When the batch's request IDs differ from the previous batch's,
# update the on-device sampling parameters.
current_batch_request_ids = [
seq_group_meta_data.request_id
for seq_group_meta_data in seq_group_metadata_list
]
if current_batch_request_ids != self._previous_batch_request_ids:
self._update_neuron_sampling_params(sampling_metadata)
self._previous_batch_request_ids = current_batch_request_ids

return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)

def _update_neuron_sampling_params(self,
sampling_metadata: SamplingMetadata):
# Update Neuron sampling parameters (GenerationConfig in Neuron)
current_sampling_params = self.model_config.neuron_sampling_params
assert current_sampling_params is not None, (
f"Failed to update sampling_params, "
f"current sampling params is {current_sampling_params}")

top_k = current_sampling_params.top_k
top_p = current_sampling_params.top_p
temperature = current_sampling_params.temperature
for index, sequence_group_to_sample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = self._convert_to_neuron_top_k(
sequence_group_to_sample.sampling_params.top_k)
top_p[index] = sequence_group_to_sample.sampling_params.top_p
temperature[index] = \
sequence_group_to_sample.sampling_params.temperature

self.model.model.update_generation_config(current_sampling_params)
Contributor: Can we also turn it off when the sequence groups being processed have different top_p, top_k, and temperature values?

Contributor: Synced offline; we need to look at how to handle sampling params with n > 1.


def _convert_to_neuron_top_k(self, top_k: int) -> int:
if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
return self._MAX_NEURON_SAMPLING_TOP_K
return top_k
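A quick worked example (not part of the diff) of the clamping above, reimplemented standalone with the same constant:

    _MAX_NEURON_SAMPLING_TOP_K = 256

    def convert_to_neuron_top_k(top_k: int) -> int:
        # vLLM uses top_k = -1 to mean "consider the full vocabulary"; Neuron
        # needs a finite value, so out-of-range values are clamped to the cap.
        if top_k < 0 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
            return _MAX_NEURON_SAMPLING_TOP_K
        return top_k

    assert convert_to_neuron_top_k(-1) == 256    # "unlimited" -> cap
    assert convert_to_neuron_top_k(500) == 256   # above the cap -> cap
    assert convert_to_neuron_top_k(40) == 40     # within range -> unchanged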

Contributor: In execute, can we also skip calling the logits processing when on-device sampling is enabled?

Author: ack!

@torch.inference_mode()
def execute_model(
self,
Expand All @@ -253,9 +323,13 @@ def execute_model(
device=self.device),
)

- # Compute the logits.
- logits = self.model.compute_logits(hidden_states,
- model_input.sampling_metadata)
# Compute the logits only when on-device sampling is turned off, since
# on-device sampling outputs the token ids directly.
if self._on_device_sampling_disabled:
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
else:
logits = hidden_states

# Sample the next token.
output = self.model.sample(