Skip to content

Commit

Permalink
add gelu_fast
Browse files Browse the repository at this point in the history
  • Loading branch information
HeegyuKim committed Oct 6, 2023
1 parent 6238285 commit 194d4f0
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,17 @@ def quick_gelu(x):
return x * jax.nn.sigmoid(1.702 * x)


def gelu_fast(x):
return 0.5 * x * (1.0 + jnp.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


ACT2FN = {
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": partial(nn.gelu, approximate=True),
"gelu_fast": gelu_fast,
"quick_gelu": quick_gelu,
}

Expand Down

0 comments on commit 194d4f0

Please sign in to comment.