[Model] Changes to MLPSpeculator to support tie_weights and input_scale (vllm-project#5965)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
2 people authored and robertgshaw2-redhat committed Jul 7, 2024
1 parent 47bc35f commit 1732299
Showing 2 changed files with 81 additions and 25 deletions.
94 changes: 69 additions & 25 deletions vllm/model_executor/models/mlp_speculator.py
@@ -13,6 +13,8 @@
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import MLPSpeculatorConfig

SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
"""
@@ -26,24 +28,30 @@ class MLPSpeculatorLayerNorm(nn.Module):
Safety term to prevent division by zero. Make sure the chosen value
fits in the range of your encoding scheme
(i.e. fp16 requires eps >= 6e-8).
elementwise_scale_and_shift : bool
Include a learned scaling and shift term after normalization.
"""

def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_scale_and_shift=True,
):
super(MLPSpeculatorLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.elementwise_scale_and_shift = elementwise_scale_and_shift
if self.elementwise_scale_and_shift:
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps

def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
x = self.weight * x
x = x + self.bias
if self.elementwise_scale_and_shift:
x = self.weight * x
x = x + self.bias
return x
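
For reference, the forward pass above is an RMS-style normalization with an optional learned elementwise scale and shift. A minimal standalone sketch of the same computation, assuming plain PyTorch outside of vLLM:

import torch

def rms_norm(x, weight=None, bias=None, eps=1e-6):
    # Normalize by the root mean square over the last dimension.
    xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    xf = xf.type_as(x)
    # Optional learned elementwise scale and shift, matching
    # elementwise_scale_and_shift=True in the class above.
    if weight is not None:
        xf = weight * xf + bias
    return xf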


@@ -59,27 +67,60 @@ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:

self.max_speculative_tokens = config.num_lookahead_tokens

self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
for _ in range(self.max_speculative_tokens)
])

self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False) for i in range(self.max_speculative_tokens)
])

self.head = nn.ModuleList([
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([
MLPSpeculatorLayerNorm(self.inner_dim)
for _ in range(self.max_speculative_tokens)
])
self.tie_weights = config.tie_weights
self.scale_input = config.scale_input

if self.tie_weights:
assert (
self.n_predict >
1), "You cannot tie weights between stages when only 1 exists"
embedding = VocabParallelEmbedding(
config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

# the initial projection from the base model may
# have a different size, so that stays separate.
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1))

head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
self.head = nn.ModuleList([head] * self.max_speculative_tokens)

ln = MLPSpeculatorLayerNorm(self.inner_dim,
elementwise_scale_and_shift=True)
self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)

else:
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
for _ in range(self.max_speculative_tokens)
])

self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False)
for i in range(self.max_speculative_tokens)
])

self.head = nn.ModuleList([
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([
MLPSpeculatorLayerNorm(self.inner_dim,
elementwise_scale_and_shift=True)
for _ in range(self.max_speculative_tokens)
])
if self.scale_input:
self.ln0 = MLPSpeculatorLayerNorm(
self.emb_dim, elementwise_scale_and_shift=False)

self.state_weight = 0.5**(0.5 / config.n_predict)
self.emb_weight = math.sqrt(
@@ -105,6 +146,9 @@ def generate_proposals(
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)

if self.scale_input:
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

# b x 1
last_tokens = input_ids.unsqueeze(1)

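A note on the tied configuration above: reusing one module object across an nn.ModuleList is what ties the weights, since every list entry then holds the same parameter tensors and gradients from all stages accumulate into them. A small illustrative sketch, independent of vLLM (the names here are hypothetical):

import torch.nn as nn

inner_dim, n_stages = 64, 3
tied = nn.Linear(inner_dim, inner_dim, bias=False)
proj = nn.ModuleList([tied] * n_stages)

# Every list entry is the same module object, so its weight tensor is shared
# and gradients from all stages accumulate into the same parameter.
assert all(stage is tied for stage in proj)
assert proj[0].weight is proj[n_stages - 1].weight

The scale_input path is simpler: generate_proposals applies the shift-free ln0 to the incoming base-model hidden states and divides by SQRT2 before the speculative loop, as shown in the hunk above.
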
12 changes: 12 additions & 0 deletions vllm/transformers_utils/configs/mlp_speculator.py
@@ -17,6 +17,8 @@ def __init__(self,
n_predict: int = 3,
top_k_tokens_per_head: Optional[List[int]] = None,
n_candidates: int = 5,
tie_weights: bool = False,
scale_input: bool = False,
**kwargs):
"""
Initialize an MLPSpeculatorConfig
@@ -38,6 +40,14 @@
NOTE: This parameter is currently unused.
n_candidates: int
number of child candidates to create per sequence
tie_weights: bool
If true, use a single set of weights for every model
head/stage after the first. The initial projection
from the base model may have a different size, so that
stays separate.
scale_input: bool
if True, will scale the initial hidden states from
the base model.
"""
if top_k_tokens_per_head is None:
top_k_tokens_per_head = [5, 4, 3]
@@ -49,5 +59,7 @@
self.top_k_tokens_per_head = top_k_tokens_per_head
self.n_candidates = n_candidates
self.num_lookahead_tokens = n_predict
self.tie_weights = tie_weights
self.scale_input = scale_input

super().__init__(**kwargs)

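Taken together, enabling the two new options from the config side might look like the following minimal sketch; arguments not shown in this diff are assumed to keep their defaults.

from vllm.transformers_utils.configs import MLPSpeculatorConfig

config = MLPSpeculatorConfig(
    n_predict=3,
    n_candidates=5,
    tie_weights=True,   # share stage weights (the initial projection stays separate)
    scale_input=True,   # normalize and scale the incoming base-model hidden states
)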