diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 7a8699e3932cb..e3d588efd9b6d 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -114,10 +114,12 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward_native(
-        self,
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
@@ -127,17 +129,32 @@ def forward_native(
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * torch.rsqrt(variance + variance_epsilon)
         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
+        x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
         return x if residual is None else (x, residual)
 
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+        if torch.compiler.is_compiling():
+            return self.forward_native(x, residual)
+
+        if not getattr(self, "_is_compiled", False):
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
         return self.forward_native(x, residual)
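
The core idea of the diff is to pull the norm math out of the bound method into a parameter-free `@staticmethod`, so that `forward_cuda` can wrap it with `torch.compile` lazily on the first eager call (and skip that wrapping when an outer compile is already tracing the module). Below is a minimal, self-contained sketch of that same lazy-compile pattern, not vLLM code: the class name `LazyCompiled`, the helper `_kernel`, and the `_is_compiled` flag are illustrative, and it assumes a PyTorch version that provides `torch.compile` and `torch.compiler.is_compiling()` (as the diff itself does).

```python
import torch


class LazyCompiled(torch.nn.Module):
    """Sketch of the pattern in the diff: keep the math in a
    parameter-free @staticmethod and swap in a torch.compile-d
    version on the first eager call."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    @staticmethod
    def _kernel(weight: torch.Tensor, eps: float,
                x: torch.Tensor) -> torch.Tensor:
        # Gemma-style RMSNorm: compute in float32, scale by (1 + w),
        # then cast back to the input dtype.
        xf = x.float()
        variance = xf.pow(2).mean(dim=-1, keepdim=True)
        return (xf * torch.rsqrt(variance + eps) *
                (1.0 + weight.float())).to(x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If an outer torch.compile is already tracing this module,
        # fall through to the plain function instead of nesting compiles.
        if torch.compiler.is_compiling():
            return self._kernel(self.weight.data, self.eps, x)
        # Compile lazily on the first real call; the instance attribute
        # shadows the class staticmethod, so later calls reuse the
        # compiled artifact.
        if not getattr(self, "_is_compiled", False):
            self._kernel = torch.compile(self._kernel)  # type: ignore
            self._is_compiled = True
        return self._kernel(self.weight.data, self.eps, x)


# Usage: the first call triggers compilation, subsequent calls hit the cache.
norm = LazyCompiled(hidden_size=8)
y = norm(torch.randn(2, 8))
```

Making the kernel a staticmethod matters here: because `self` is not an argument, the compiled function only sees plain tensors and a float, so the compiled graph can be shared and is not tied to module state.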