From 8729e85bb8bccb4a33044386c878a4a6b7218027 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 22 Aug 2024 01:14:13 -0700
Subject: [PATCH] [Misc] Use torch.compile for GemmaRMSNorm (#7642)

---
 vllm/model_executor/layers/layernorm.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 7a8699e3932cb..e3d588efd9b6d 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -114,10 +114,12 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward_native(
-        self,
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
@@ -127,17 +129,32 @@ def forward_native(
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * torch.rsqrt(variance + variance_epsilon)
         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
+        x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
         return x if residual is None else (x, residual)
 
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+        if torch.compiler.is_compiling():
+            return self.forward_native(x, residual)
+
+        if not getattr(self, "_is_compiled", False):
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
         return self.forward_native(x, residual)
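
Standalone sketch of the pattern this diff applies: the normalization math is moved
into a @staticmethod that takes the weight and epsilon as arguments, so torch.compile
traces a plain function rather than a bound method, and the eager path compiles it
lazily on first use while skipping compilation when already inside an outer
torch.compile region. The class name GemmaRMSNormSketch and the __main__ demo are
hypothetical, and vLLM's CustomOp dispatch (forward_native/forward_cuda) is collapsed
into a single forward() for brevity; only torch is assumed.

# Sketch of lazily torch.compile-ing a static RMSNorm function.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class GemmaRMSNormSketch(nn.Module):

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # A plain function of tensors and a float: torch.compile traces it
        # without capturing `self`, unlike a bound method.
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Gemma computes (x * (1 + w)).to(orig_dtype), not x.to(orig_dtype) * w.
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if torch.compiler.is_compiling():
            # Already inside an outer torch.compile region: run the plain
            # function and let the outer graph handle it.
            return self.forward_static(self.weight.data,
                                       self.variance_epsilon, x, residual)
        if not getattr(self, "_is_compiled", False):
            # Compile lazily on the first eager call; the compiled callable
            # is stored as an instance attribute, shadowing the staticmethod.
            self.forward_static = torch.compile(self.forward_static)
            self._is_compiled = True
        return self.forward_static(self.weight.data, self.variance_epsilon,
                                   x, residual)


if __name__ == "__main__":
    layer = GemmaRMSNormSketch(hidden_size=8)
    out = layer(torch.randn(2, 8))  # first call triggers compilation
    out = layer(torch.randn(2, 8))  # later calls reuse the compiled function
    print(out.shape)                # torch.Size([2, 8])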