diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 345919c5d1636..43ea4eb5a4d1a 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -20,6 +20,7 @@ def __init__(
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
     ) -> None:
         super().__init__()
 
@@ -27,7 +28,11 @@ def __init__(
         self.variance_epsilon = eps
         self.variance_size_override = (None if var_hidden_size == hidden_size
                                        else var_hidden_size)
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.has_weight = has_weight
+
+        self.weight = torch.ones(hidden_size)
+        if self.has_weight:
+            self.weight = nn.Parameter(self.weight)
 
     def forward_native(
         self,
@@ -59,7 +64,9 @@ def forward_native(
             variance = x_var.pow(2).mean(dim=-1, keepdim=True)
 
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = x.to(orig_dtype)
+        if self.has_weight:
+            x = x * self.weight
         if residual is None:
             return x
         else:
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 8ef0a6cdf2c52..10bec75f49fdf 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -40,6 +40,7 @@ def __init__(self,
                  use_conv_bias: bool,
                  use_bias: bool,
                  use_rms_norm: bool,
+                 rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
                  activation="silu"):
         super().__init__()
@@ -105,14 +106,23 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             input_is_parallel=True,
         )
 
-        self.dt_layernorm = RMSNorm(time_step_rank,
-                                    eps=rms_norm_eps) if use_rms_norm else None
-
-        self.b_layernorm = RMSNorm(ssm_state_size,
-                                   eps=rms_norm_eps) if use_rms_norm else None
-
-        self.c_layernorm = RMSNorm(ssm_state_size,
-                                   eps=rms_norm_eps) if use_rms_norm else None
+        self.dt_layernorm = RMSNorm(
+            time_step_rank,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None
+
+        self.b_layernorm = RMSNorm(
+            ssm_state_size,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None
+
+        self.c_layernorm = RMSNorm(
+            ssm_state_size,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None
 
     def forward_native(self, hidden_states: torch.Tensor,
                        attn_metadata: AttentionMetadata,
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index b32032e411b0a..8bdcd2c5aad1f 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -1,5 +1,5 @@
 """PyTorch MAMBA model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -47,6 +47,7 @@ def __init__(self,
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 use_rms_norm=self.is_falcon_mamba,
+                                rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act)
 
@@ -241,8 +242,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
@@ -254,3 +257,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
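
For context, a minimal standalone sketch (not the vLLM class itself) of what the `has_weight=False` path added above computes: the input is RMS-normalized and cast back to its original dtype, but no learned scale is applied afterwards, matching FalconMamba's parameterless dt/B/C norms. The `rms_norm` helper and the tensor shapes here are illustrative assumptions, not part of the patch.

```python
from typing import Optional

import torch


def rms_norm(x: torch.Tensor,
             eps: float = 1e-5,
             weight: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Illustrative RMSNorm: skip the scale when `weight` is None,
    mirroring the has_weight=False branch introduced in the patch."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype)
    if weight is not None:  # corresponds to has_weight=True
        x = x * weight
    return x


hidden = torch.randn(2, 4, 16, dtype=torch.float16)
# FalconMamba-style norm: normalization only, no learnable weight.
out_no_weight = rms_norm(hidden)
# With an all-ones weight the output is the same, so dropping the
# parameter is numerically equivalent to a freshly initialized norm.
out_weighted = rms_norm(hidden, weight=torch.ones(16, dtype=torch.float16))
assert torch.allclose(out_no_weight, out_weighted)
```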