From 1cf7593035d02cfe4ade51d7ac3ce84fb081cd33 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Fri, 26 Jan 2024 17:51:53 -0800 Subject: [PATCH] enforce mask kwarg in norm layers --- flax/linen/normalization.py | 10 +++++----- tests/linen/linen_test.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 1535a7e8f8..78f52ce88a 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -298,7 +298,7 @@ class BatchNorm(Module): use_fast_variance: bool = True @compact - def __call__(self, x, use_running_average: Optional[bool] = None, mask=None): + def __call__(self, x, use_running_average: Optional[bool] = None, *, mask: Optional[jax.Array] = None): """Normalizes the input using batch statistics. NOTE: @@ -434,7 +434,7 @@ class LayerNorm(Module): use_fast_variance: bool = True @compact - def __call__(self, x, mask=None): + def __call__(self, x, *, mask: Optional[jax.Array] = None): """Applies layer normalization on the input. Args: @@ -528,7 +528,7 @@ class RMSNorm(Module): axis_index_groups: Any = None @compact - def __call__(self, x, mask=None): + def __call__(self, x, *, mask: Optional[jax.Array] = None): """Applies layer normalization on the input. Args: @@ -637,7 +637,7 @@ class GroupNorm(Module): use_fast_variance: bool = True @compact - def __call__(self, x, mask=None): + def __call__(self, x, *, mask: Optional[jax.Array] = None): """Applies group normalization to the input (arxiv.org/abs/1803.08494). Args: @@ -807,7 +807,7 @@ class InstanceNorm(Module): use_fast_variance: bool = True @compact - def __call__(self, x, mask=None): + def __call__(self, x, *, mask: Optional[jax.Array] = None): """Applies instance normalization on the input. Args: diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 8873507063..012532585e 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -146,7 +146,7 @@ def test_layer_norm_mask(self): x = jnp.where(m, x, jnp.nan) module = nn.LayerNorm() - y, w = module.init_with_output(key, x, m) + y, w = module.init_with_output(key, x, mask=m) z = y.mean(-1, where=m) np.testing.assert_allclose(z, 0, atol=1e-4) @@ -163,7 +163,7 @@ def test_rms_norm_mask(self): x = jnp.where(m, x, jnp.nan) module = nn.RMSNorm() - y, w = module.init_with_output(key, x, m) + y, w = module.init_with_output(key, x, mask=m) z = np.square(y).mean(-1, where=m) np.testing.assert_allclose(z, 1, atol=1e-4) @@ -177,7 +177,7 @@ def test_group_norm_mask(self): x = jnp.where(m, x, jnp.nan) module = nn.GroupNorm(7, use_bias=False, use_scale=False) - y, w = module.init_with_output(key, x, m) + y, w = module.init_with_output(key, x, mask=m) yr = y.reshape((13, 3, 5, 7, 11)) mr = m.reshape((13, 3, 5, 7, 11))