From 1732299451f0342166f701df322714c6cd8ad880 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Tue, 2 Jul 2024 01:40:02 +0200
Subject: [PATCH] [Model] Changes to MLPSpeculator to support tie_weights and
 input_scale (#5965)

Signed-off-by: Thomas Parnell
Co-authored-by: Joshua Rosenkranz
---
 vllm/model_executor/models/mlp_speculator.py | 94 ++++++++++++++-----
 .../configs/mlp_speculator.py                | 12 +++
 2 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index 6e6b2d8a7edb0..290a703af6ffa 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -13,6 +13,8 @@
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
+SQRT2 = 2**0.5
+
 
 class MLPSpeculatorLayerNorm(nn.Module):
     """
@@ -26,24 +28,30 @@ class MLPSpeculatorLayerNorm(nn.Module):
         Safety term to prevent division by zero. Make sure the chosen value
         fits in the range of your encoding scheme
         (i.e. fp16 requires eps >= 6e-8).
+    elementwise_scale_and_shift : bool
+        Include a learned scaling and shift term after normalization.
     """

     def __init__(
         self,
         normalized_shape,
         eps=1e-06,
+        elementwise_scale_and_shift=True,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
-        self.weight = nn.Parameter(torch.empty(normalized_shape))
-        self.bias = nn.Parameter(torch.empty(normalized_shape))
+        self.elementwise_scale_and_shift = elementwise_scale_and_shift
+        if self.elementwise_scale_and_shift:
+            self.weight = nn.Parameter(torch.empty(normalized_shape))
+            self.bias = nn.Parameter(torch.empty(normalized_shape))
         self.eps = eps

     def forward(self, x):
         xf = x
         xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
         x = xf.type_as(x)
-        x = self.weight * x
-        x = x + self.bias
+        if self.elementwise_scale_and_shift:
+            x = self.weight * x
+            x = x + self.bias
         return x


@@ -59,27 +67,60 @@ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:

         self.max_speculative_tokens = config.num_lookahead_tokens

-        self.emb = nn.ModuleList([
-            VocabParallelEmbedding(config.vocab_size,
-                                   self.inner_dim,
-                                   org_num_embeddings=config.vocab_size)
-            for _ in range(self.max_speculative_tokens)
-        ])
-
-        self.proj = nn.ModuleList([
-            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
-                      self.inner_dim,
-                      bias=False) for i in range(self.max_speculative_tokens)
-        ])
-
-        self.head = nn.ModuleList([
-            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
-            for _ in range(self.max_speculative_tokens)
-        ])
-        self.ln = nn.ModuleList([
-            MLPSpeculatorLayerNorm(self.inner_dim)
-            for _ in range(self.max_speculative_tokens)
-        ])
+        self.tie_weights = config.tie_weights
+        self.scale_input = config.scale_input
+
+        if self.tie_weights:
+            assert (
+                self.n_predict >
+                1), "You cannot tie weights between stages when only 1 exists"
+            embedding = VocabParallelEmbedding(
+                config.vocab_size,
+                self.inner_dim,
+                org_num_embeddings=config.vocab_size)
+            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
+
+            # the initial projection from the base model may
+            # have a different size, so that stays separate.
+            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
+            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
+            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
+                                      (self.max_speculative_tokens - 1))
+
+            head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+
+            ln = MLPSpeculatorLayerNorm(self.inner_dim,
+                                        elementwise_scale_and_shift=True)
+            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
+
+        else:
+            self.emb = nn.ModuleList([
+                VocabParallelEmbedding(config.vocab_size,
+                                       self.inner_dim,
+                                       org_num_embeddings=config.vocab_size)
+                for _ in range(self.max_speculative_tokens)
+            ])
+
+            self.proj = nn.ModuleList([
+                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
+                          self.inner_dim,
+                          bias=False)
+                for i in range(self.max_speculative_tokens)
+            ])
+
+            self.head = nn.ModuleList([
+                nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+                for _ in range(self.max_speculative_tokens)
+            ])
+            self.ln = nn.ModuleList([
+                MLPSpeculatorLayerNorm(self.inner_dim,
+                                       elementwise_scale_and_shift=True)
+                for _ in range(self.max_speculative_tokens)
+            ])
+        if self.scale_input:
+            self.ln0 = MLPSpeculatorLayerNorm(
+                self.emb_dim, elementwise_scale_and_shift=False)

         self.state_weight = 0.5**(0.5 / config.n_predict)
         self.emb_weight = math.sqrt(
@@ -105,6 +146,9 @@ def generate_proposals(
         # b x 1 x d
         previous_hidden_states = previous_hidden_states.unsqueeze(1)

+        if self.scale_input:
+            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+
         # b x 1
         last_tokens = input_ids.unsqueeze(1)

diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py
index e1c1f4a960128..946af4e919f7c 100644
--- a/vllm/transformers_utils/configs/mlp_speculator.py
+++ b/vllm/transformers_utils/configs/mlp_speculator.py
@@ -17,6 +17,8 @@ def __init__(self,
                  n_predict: int = 3,
                  top_k_tokens_per_head: Optional[List[int]] = None,
                  n_candidates: int = 5,
+                 tie_weights: bool = False,
+                 scale_input: bool = False,
                  **kwargs):
         """
         Initialize an MLPSpeculatorConfig
@@ -38,6 +40,14 @@ def __init__(self,
             NOTE: This parameter is currently unused.
         n_candidates: int
             number of child candidates to create per sequence
+        tie_weights: bool
+            If true, use a single set of weights for every model
+            head/stage after the first. The initial projection
+            from the base model may have a different size, so that
+            stays separate.
+        scale_input: bool
+            if True, will scale the initial hidden states from
+            the base model.
         """
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = [5, 4, 3]
@@ -49,5 +59,7 @@ def __init__(self,
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
         self.num_lookahead_tokens = n_predict
+        self.tie_weights = tie_weights
+        self.scale_input = scale_input

         super().__init__(**kwargs)
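
For context, the sketch below is a minimal, standalone PyTorch illustration of the two mechanisms this patch introduces: weight tying by repeating one module object inside an nn.ModuleList, and input scaling via a norm with no learned scale or shift followed by division by sqrt(2). It is not vLLM code; the class name TinyNorm and the toy sizes are made-up placeholders, kept only close enough to the patch to show the idea.

# Standalone illustration (not vLLM code): weight tying via repeated module
# references in an nn.ModuleList, and input scaling with a scale/shift-free
# RMS-style norm divided by sqrt(2). TinyNorm and the sizes are made up.
import torch
import torch.nn as nn

SQRT2 = 2**0.5


class TinyNorm(nn.Module):
    """RMS-style norm with an optional learned scale and shift, mirroring the
    behaviour of MLPSpeculatorLayerNorm(elementwise_scale_and_shift=...)."""

    def __init__(self, dim, eps=1e-6, elementwise_scale_and_shift=True):
        super().__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if elementwise_scale_and_shift:
            self.weight = nn.Parameter(torch.ones(dim))
            self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x + self.bias
        return x


emb_dim, inner_dim, n_predict = 8, 8, 3

# Weight tying: placing the *same* module object in an nn.ModuleList several
# times makes every stage after the first share one set of parameters. The
# first projection keeps its own weights because the base-model hidden size
# may differ from the speculator's inner size.
proj_first = nn.Linear(emb_dim, inner_dim, bias=False)
proj_tied = nn.Linear(inner_dim, inner_dim, bias=False)
proj = nn.ModuleList([proj_first] + [proj_tied] * (n_predict - 1))
assert proj[1].weight is proj[2].weight       # truly the same tensor
assert proj[0].weight is not proj[1].weight   # first stage stays separate

# Input scaling: normalize the base model's hidden states without a learned
# scale or shift, then divide by sqrt(2) before the first projection.
ln0 = TinyNorm(emb_dim, elementwise_scale_and_shift=False)
hidden_states = torch.randn(2, 1, emb_dim)    # b x 1 x d
scaled = ln0(hidden_states) / SQRT2
print(scaled.shape)                           # torch.Size([2, 1, 8])

Because the tied stages share a single parameter tensor, the projection/head/norm parameter count for stages after the first no longer grows with n_predict; the assertions in the sketch make that sharing explicit.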