Skip to content

Commit

Permalink
Remove references to deprecated jax.nn.normalize
Browse files Browse the repository at this point in the history
jax.nn.standardize is a drop-in replacement after deprecation (see jax-ml/jax#18439)

PiperOrigin-RevId: 580728570
  • Loading branch information
Jake VanderPlas authored and Flax Authors committed Nov 9, 2023
1 parent 50de4c4 commit ac28613
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flax/linen/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from jax.nn import log_sigmoid
from jax.nn import log_softmax
from jax.nn import logsumexp
from jax.nn import normalize
from jax.nn import one_hot
from jax.nn import relu
from jax.nn import relu6
Expand All @@ -48,6 +47,9 @@
import jax.numpy as jnp
from jax.numpy import tanh

# Normalize is a deprecated alias of standardize
normalize = standardize

# pylint: enable=unused-import


Expand Down

0 comments on commit ac28613

Please sign in to comment.