Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemma capping is a must for big models #31698

Merged
merged 6 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
attention_bias=False,
attention_dropout=0.0,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
query_pre_attn_scalar=224,
sliding_window=4096,
**kwargs,
Expand All @@ -135,6 +137,7 @@ def __init__(
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.attn_logit_softcapping = attn_logit_softcapping

super().__init__(
pad_token_id=pad_token_id,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ def forward(

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
Expand Down
Loading