From ac2861392a21042a9d27932e317f565b7df9447f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 8 Nov 2023 18:23:49 -0800 Subject: [PATCH] Remove references to deprecated jax.nn.normalize jax.nn.standardize is a drop-in replacement after deprecation (see https://github.com/google/jax/pull/18439) PiperOrigin-RevId: 580728570 --- flax/linen/activation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flax/linen/activation.py b/flax/linen/activation.py index b5d22cc460..ac3c194944 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -33,7 +33,6 @@ from jax.nn import log_sigmoid from jax.nn import log_softmax from jax.nn import logsumexp -from jax.nn import normalize from jax.nn import one_hot from jax.nn import relu from jax.nn import relu6 @@ -48,6 +47,9 @@ import jax.numpy as jnp from jax.numpy import tanh +# Normalize is a deprecated alias of standardize +normalize = standardize + # pylint: enable=unused-import