diff --git a/transformerXL.py b/transformerXL.py index 34731fe..f44bed4 100644 --- a/transformerXL.py +++ b/transformerXL.py @@ -75,7 +75,7 @@ def __call__(self, values_keys:jnp.ndarray, queries:jnp.ndarray, pos_embed:jnp.n out = self.dense1(out_attention_n) out = nn.activation.gelu(out) #out = nn.activation.relu(out) - out = self.dense2(out_attention) + out = self.dense2(out) if(self.gating): out= self.gate2(out,jax.nn.relu(out_attention)) else: