diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index 10fa22264..f37369e38 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -300,7 +300,13 @@ class BatchNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, use_running_average: Optional[bool] = None, mask=None):
+  def __call__(
+    self,
+    x,
+    use_running_average: Optional[bool] = None,
+    *,
+    mask: Optional[jax.Array] = None,
+  ):
     """Normalizes the input using batch statistics.
 
     NOTE:
@@ -449,7 +455,7 @@ class LayerNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies layer normalization on the input.
 
     Args:
@@ -543,7 +549,7 @@ class RMSNorm(Module):
   axis_index_groups: Any = None
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies RMS layer normalization on the input.
 
     Args:
@@ -664,7 +670,7 @@ class GroupNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies group normalization to the input (arxiv.org/abs/1803.08494).
 
     Args:
@@ -834,7 +840,7 @@ class InstanceNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies instance normalization on the input.
 
     Args:
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index 9af329018..d3184e2c4 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -146,7 +146,7 @@ def test_layer_norm_mask(self):
     x = jnp.where(m, x, jnp.nan)
 
     module = nn.LayerNorm()
-    y, w = module.init_with_output(key, x, m)
+    y, w = module.init_with_output(key, x, mask=m)
     z = y.mean(-1, where=m)
     np.testing.assert_allclose(z, 0, atol=1e-4)
 
@@ -163,7 +163,7 @@ def test_rms_norm_mask(self):
     x = jnp.where(m, x, jnp.nan)
 
     module = nn.RMSNorm()
-    y, w = module.init_with_output(key, x, m)
+    y, w = module.init_with_output(key, x, mask=m)
     z = np.square(y).mean(-1, where=m)
     np.testing.assert_allclose(z, 1, atol=1e-4)
 
@@ -177,7 +177,7 @@ def test_group_norm_mask(self):
     x = jnp.where(m, x, jnp.nan)
 
     module = nn.GroupNorm(7, use_bias=False, use_scale=False)
-    y, w = module.init_with_output(key, x, m)
+    y, w = module.init_with_output(key, x, mask=m)
     yr = y.reshape((13, 3, 5, 7, 11))
     mr = m.reshape((13, 3, 5, 7, 11))