Skip to content

Commit

Permalink
Fixed implementation error leading to only 1 MLP layer instead of 2 a…
Browse files Browse the repository at this point in the history
…nd normalization layer
  • Loading branch information
Reytuag authored Nov 2, 2024
1 parent 0f9778a commit db3e52a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformerXL.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __call__(self, values_keys:jnp.ndarray, queries:jnp.ndarray, pos_embed:jnp.n
out = self.dense1(out_attention_n)
out = nn.activation.gelu(out)
#out = nn.activation.relu(out)
out = self.dense2(out_attention)
out = self.dense2(out)
if(self.gating):
out= self.gate2(out,jax.nn.relu(out_attention))
else:
Expand Down

0 comments on commit db3e52a

Please sign in to comment.