diff --git a/flax/linen/activation.py b/flax/linen/activation.py index b5d22cc460..ac3c194944 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -33,7 +33,6 @@ from jax.nn import log_sigmoid from jax.nn import log_softmax from jax.nn import logsumexp -from jax.nn import normalize from jax.nn import one_hot from jax.nn import relu from jax.nn import relu6 @@ -48,6 +47,9 @@ import jax.numpy as jnp from jax.numpy import tanh +# Normalize is a deprecated alias of standardize +normalize = standardize + # pylint: enable=unused-import