
Commit

Merge pull request #3663 from chiamp:mask
PiperOrigin-RevId: 605467792
Flax Authors committed Feb 9, 2024
2 parents 37fe9c9 + 1cf7593 commit 0a88cfd
Showing 2 changed files with 14 additions and 8 deletions.
16 changes: 11 additions & 5 deletions flax/linen/normalization.py
@@ -300,7 +300,13 @@ class BatchNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, use_running_average: Optional[bool] = None, mask=None):
+  def __call__(
+    self,
+    x,
+    use_running_average: Optional[bool] = None,
+    *,
+    mask: Optional[jax.Array] = None,
+  ):
     """Normalizes the input using batch statistics.
 
     NOTE:
@@ -449,7 +455,7 @@ class LayerNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies layer normalization on the input.
 
     Args:
@@ -543,7 +549,7 @@ class RMSNorm(Module):
   axis_index_groups: Any = None
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies RMS layer normalization on the input.
 
     Args:
@@ -664,7 +670,7 @@ class GroupNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies group normalization to the input (arxiv.org/abs/1803.08494).
 
     Args:
@@ -834,7 +840,7 @@ class InstanceNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x, mask=None):
+  def __call__(self, x, *, mask: Optional[jax.Array] = None):
     """Applies instance normalization on the input.
 
     Args:
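After this change, `mask` is keyword-only on every normalization layer's `__call__`. A minimal sketch of the new calling convention, assuming Flax at this commit; the shapes and mask pattern below are illustrative, not taken from the diff:

import jax
import jax.numpy as jnp
import flax.linen as nn

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 16))                    # illustrative input batch
mask = jnp.broadcast_to(jnp.arange(16) < 12, x.shape)  # keep only the first 12 features

module = nn.LayerNorm()
y, variables = module.init_with_output(key, x, mask=mask)  # mask passed by keyword

# Passing the mask positionally no longer works:
# module.init_with_output(key, x, mask)  # TypeError (mask is keyword-only)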
6 changes: 3 additions & 3 deletions tests/linen/linen_test.py
@@ -146,7 +146,7 @@ def test_layer_norm_mask(self):
     x = jnp.where(m, x, jnp.nan)
 
     module = nn.LayerNorm()
-    y, w = module.init_with_output(key, x, m)
+    y, w = module.init_with_output(key, x, mask=m)
 
     z = y.mean(-1, where=m)
     np.testing.assert_allclose(z, 0, atol=1e-4)
@@ -163,7 +163,7 @@ def test_rms_norm_mask(self):
     x = jnp.where(m, x, jnp.nan)
 
     module = nn.RMSNorm()
-    y, w = module.init_with_output(key, x, m)
+    y, w = module.init_with_output(key, x, mask=m)
 
     z = np.square(y).mean(-1, where=m)
     np.testing.assert_allclose(z, 1, atol=1e-4)
@@ -177,7 +177,7 @@ def test_group_norm_mask(self):
     x = jnp.where(m, x, jnp.nan)
 
     module = nn.GroupNorm(7, use_bias=False, use_scale=False)
-    y, w = module.init_with_output(key, x, m)
+    y, w = module.init_with_output(key, x, mask=m)
 
     yr = y.reshape((13, 3, 5, 7, 11))
     mr = m.reshape((13, 3, 5, 7, 11))
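These tests also pin down the mask semantics: masked positions are excluded from the layer's statistics, so even NaNs placed there do not leak into the normalized output. A condensed sketch of the RMSNorm case, assuming Flax at this commit; the shapes are illustrative:

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 32))
m = jnp.broadcast_to(jnp.arange(32) < 24, x.shape)  # mask out the last 8 features
x = jnp.where(m, x, jnp.nan)                        # poison masked positions with NaN

y, _ = nn.RMSNorm().init_with_output(key, x, mask=m)

# The unmasked positions still normalize to unit mean square:
z = np.square(y).mean(-1, where=m)
np.testing.assert_allclose(z, 1, atol=1e-4)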
