Add `ops.nn.moments` and speed up normalization layers #866
Conversation
Codecov Report

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #866      +/-   ##
==========================================
+ Coverage   76.49%   76.56%   +0.06%
==========================================
  Files         329      329
  Lines       31334    31422      +88
  Branches     6100     6113      +13
==========================================
+ Hits        23970    24057      +87
- Misses       5785     5786       +1
  Partials     1579     1579
```
Thanks for the PR, this is great! What performance changes did you observe in e.g. `BatchNormalization` when using `ops.nn.moments`?
I have reordered the operations in `ops.nn.moments`:

```python
# actually, tf.nn.moments could be faster...
# original
variance = jnp.mean(
    jnp.square(x - jax.lax.stop_gradient(mean)), axis=axes, keepdims=True
)
# faster version
variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
    jax.lax.stop_gradient(mean)
)
```

(A small numeric check of this reordering appears after the benchmark script below.)

I have also updated the normalization layers. I observed better performance in TF & JAX, but torch showed similar performance.

Benchmark script:

```python
from keras_core import layers
from keras_core import models
from keras_core import ops
x_train = ops.random.uniform(shape=(1024, 224, 224, 3))
y_train = ops.random.uniform(shape=(1024, 224, 224, 3))
# layers.BatchNormalization
# layers.GroupNormalization
# layers.LayerNormalization
normalization_cls = layers.BatchNormalization
normalization_args = {}
if normalization_cls is layers.GroupNormalization:
    normalization_args = {"groups": 3}
model = models.Sequential(
    [
        layers.InputLayer(shape=(224, 224, 3)),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
    ]
)
model.compile(loss="mse", optimizer="adam")
model.fit(x_train, y_train, batch_size=128, epochs=3)
```

Results (with a GTX 1080 8 GB card):
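Below is the small numeric check referred to above: the reordering relies on the identity Var(x) = E[x^2] - E[x]^2, and this minimal NumPy sketch (illustrative only, not part of the PR or the benchmark) confirms that the two formulations agree for well-scaled float32 data.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 32, 32, 3)).astype("float32")
axes = (0, 1, 2)

mean = x.mean(axis=axes, keepdims=True)
# original formulation: mean of squared deviations
var_original = np.mean(np.square(x - mean), axis=axes, keepdims=True)
# reordered formulation: E[x^2] - E[x]^2
var_fast = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)

# the two agree to within float32 round-off for well-scaled inputs
print(np.max(np.abs(var_original - var_fast)))
```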
Hi @fchollet

**In TensorFlow**

```python
variance = math_ops.reduce_mean(
    math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
    axes,
    keepdims=True,
    name="variance")
```

**In Keras Core and Flax**

Keras Core and Flax (the NN library for JAX) use a faster but less stable version for computing the variance:

```python
# use_fast_variance=True by default in Flax
if use_fast_variance:
    mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
    # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
    # to floating point round-off errors.
    var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
else:
    mu = maybe_distributed_mean(x)
    var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
```

**Torch**

I just checked the CPU version, and torch should use the same approach as TensorFlow.

**Performance Comparison (TensorFlow)**

Which one should we take?
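To make the stability trade-off concrete, here is a minimal NumPy sketch (illustrative only, not taken from TensorFlow, Flax, or torch) showing how the fast formula can suffer catastrophic cancellation and even go negative, which is exactly what Flax's `jnp.maximum(0.0, ...)` clamp guards against:

```python
import numpy as np

rng = np.random.default_rng(0)
# small noise around a large constant offset, stored in float32
x = (rng.normal(scale=1e-3, size=100_000) + 1000.0).astype(np.float32)

mean = x.mean()
fast_var = np.mean(np.square(x)) - np.square(mean)
stable_var = np.mean(np.square(x - mean))

print(fast_var)            # dominated by round-off; can even be negative
print(stable_var)          # close to the true variance of ~1e-6
print(max(0.0, fast_var))  # clamping guarantees a non-negative result
```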
Thanks for the analysis -- what does "current implementation" refer to? The implementation in this PR?
Sorry for the confusion. I have updated the table as follows:

References:

The question remains: should we adopt the fast but unstable variance computation, or the stable but slower version?
My take is that until we see reports of users running into stability issues, the fast implementation should be fine. The fact that Flax defaults to it is evidence that there's little issue. I did a quick search within the google codebase and found only a couple of usages of
The code looks good! Do you want to include the normalization layer changes in this PR, or merge this PR first and then create another one?
Thanks for the valuable insights about the usage of Flax.
I think the changes are already in this PR. Please let me know if I missed anything.
Hi @fchollet
Additionally, this PR addresses the overflow and underflow issues that occur when the input is float16. (I encountered this before when using GroupNormalization with mixed_float16.)
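For readers unfamiliar with the failure mode, here is a minimal NumPy illustration (not the PR's actual fix; the upcast-to-float32 strategy shown is just one common remedy):

```python
import numpy as np

x = np.full((4, 8), 300.0, dtype=np.float16)

# float16 squares overflow to inf once |x| exceeds ~256 (float16 max is 65504),
# so the naive statistics come out as inf - inf = nan
naive_var = np.mean(np.square(x)) - np.square(np.mean(x))

# upcasting to float32 before computing the statistics avoids the overflow
x32 = x.astype(np.float32)
safe_var = np.mean(np.square(x32)) - np.square(np.mean(x32))

print(naive_var, safe_var)  # nan 0.0
```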
LGTM, thank you for the great contribution!
This PR adds `ops.nn.moments` and speeds up some normalization layers with fast mean and variance computation. It also addresses the overflow & underflow issue when the input tensor is float16.

There are two approaches to computing the variance in `ops.nn.moments`:

- fast but unstable (this PR)
- stable but slower

References:

- `tf.nn.moments` uses the stable but slower version (link)
- Flax `use_fast_variance` (link)
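Putting the pieces of the discussion together, here is a hedged JAX sketch of what a fast, float16-safe moments computation could look like. It is illustrative only: the function name `moments_sketch` and the upcast behavior are assumptions, not necessarily how the PR's `ops.nn.moments` is actually implemented.

```python
import jax
import jax.numpy as jnp


def moments_sketch(x, axes, keepdims=True):
    """Fast-variance moments with a float16 upcast; illustrative only."""
    # Upcast low-precision inputs so the squared terms cannot overflow.
    needs_cast = x.dtype == jnp.float16
    if needs_cast:
        x = x.astype(jnp.float32)

    mean = jnp.mean(x, axis=axes, keepdims=keepdims)
    # Fast formula E[x^2] - E[x]^2, clamped because floating point round-off
    # can make the difference slightly negative (same trick as Flax).
    variance = jnp.maximum(
        0.0,
        jnp.mean(jnp.square(x), axis=axes, keepdims=keepdims)
        - jnp.square(jax.lax.stop_gradient(mean)),
    )

    if needs_cast:
        mean = mean.astype(jnp.float16)
        variance = variance.astype(jnp.float16)
    return mean, variance


x = jax.random.uniform(jax.random.PRNGKey(0), (8, 32, 32, 3), dtype=jnp.float16)
mean, variance = moments_sketch(x, axes=(0, 1, 2))
print(mean.dtype, variance.dtype)  # float16 float16
```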