Skip to content

Commit

Permalink
[Neuron] Adding support for context-lenght, token-gen buckets. (vllm-…
Browse files Browse the repository at this point in the history
…project#7885)

Co-authored-by: Harsha Bikki <harbikh@amazon.com>
  • Loading branch information
hbikki and Harsha Bikki authored Aug 29, 2024
1 parent 86a677d commit 257afc3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
11 changes: 9 additions & 2 deletions examples/offline_inference_neuron.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import os

from vllm import LLM, SamplingParams

# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"

# Sample prompts.
prompts = [
"Hello, my name is",
Expand All @@ -19,8 +26,8 @@
# Currently, this is a known limitation in continuous batching support
# in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=128,
block_size=128,
max_model_len=2048,
block_size=2048,
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
Expand Down
33 changes: 24 additions & 9 deletions vllm/model_executor/model_loader/neuron.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utilities for selecting and loading neuron models."""
import importlib
import os
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")


def _get_buckets(env: str, default_value: List[int]) -> List[int]:
env_value = os.getenv(env)
if env_value is None:
return default_value
buckets_remove_empty = filter(
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
buckets_int = map(int, buckets_remove_empty)
buckets_list = list(buckets_int)
return buckets_list


def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
Expand All @@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config)

context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])

# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len],
n_positions=[scheduler_config.max_model_len],
batch_size=scheduler_config.max_num_seqs)
model.load_weights(model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)

return model.eval()

0 comments on commit 257afc3

Please sign in to comment.