Skip to content

Commit

Permalink
[Misc] Use torch.compile for GemmaRMSNorm (vllm-project#7642)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Aug 22, 2024
1 parent 3d370fd commit 1481451
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ def __init__(
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward_native(
self,
@staticmethod
def forward_static(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
Expand All @@ -127,17 +129,32 @@ def forward_native(

x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + self.weight.float())
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
if torch.compiler.is_compiling():
return self.forward_native(x, residual)

if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore
self.forward_static)
self._is_compiled = True
return self.forward_native(x, residual)

0 comments on commit 1481451

Please sign in to comment.