
Add ops.nn.moments and speed-up normalization layers #866

Merged — 5 commits merged into keras-team:main from james77777778:add-moments on Sep 14, 2023

Conversation

@james77777778 (Contributor) commented on Sep 12, 2023:

EDITED:
This PR adds ops.nn.moments and speeds up several normalization layers via fast mean and variance computation.
It also addresses the overflow and underflow issues that occur when the input tensor is float16.

There are two approaches to compute variance:

  • $Var = E[|x - E[x]|^2]$ (stable but slower)
  • $Var = E[|x|^2] - |E[x]|^2$ (fast but unstable)
| backend | layer | manual implementation (before this PR) | ops.nn.moments (fast but unstable, this PR) | ops.nn.moments (stable but slower) |
| --- | --- | --- | --- | --- |
| tensorflow | BatchNormalization | 58ms | 58ms | 69ms |
| jax | BatchNormalization | 63ms | 63ms | 75ms |
| torch | BatchNormalization | 73ms | 72ms | 74ms |
| tensorflow | GroupNormalization | 89ms | 61ms | 73ms |
| jax | GroupNormalization | 96ms | 72ms | 83ms |
| torch | GroupNormalization | 72ms | 74ms | 76ms |
| tensorflow | LayerNormalization | 52ms | 47ms | 48ms |
| jax | LayerNormalization | 68ms | 59ms | 60ms |
| torch | LayerNormalization | 88ms | 90ms | 91ms |
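
For illustration only (this is not code from the PR), a minimal NumPy sketch of the two formulas, including the clamp that Flax applies because round-off can make the fast variance slightly negative:

```python
# Illustrative sketch only; not the PR's implementation.
import numpy as np

x = np.random.uniform(size=(8, 32, 32, 3)).astype("float32")
axes = (0, 1, 2)

mean = np.mean(x, axis=axes, keepdims=True)

# Stable but slower: subtract the mean from the full tensor, then reduce.
var_stable = np.mean(np.square(x - mean), axis=axes, keepdims=True)

# Fast but unstable: reduce first, then combine the small reduced tensors.
# Floating-point round-off can make the difference negative, hence the clamp.
var_fast = np.maximum(
    np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean), 0.0
)
```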

References:

  • tensorflow: tf.nn.moments uses the stable but slower version (link)
  • flax: defaults to use_fast_variance=True (link)
  • torch: appears to use the stable but slower version (link)

@codecov bot commented on Sep 12, 2023:

Codecov Report

Patch coverage: 98.94% and project coverage change: +0.06% 🎉

Comparison: base (bb21710) 76.49% vs. head (8f06862) 76.56%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #866      +/-   ##
==========================================
+ Coverage   76.49%   76.56%   +0.06%     
==========================================
  Files         329      329              
  Lines       31334    31422      +88     
  Branches     6100     6113      +13     
==========================================
+ Hits        23970    24057      +87     
- Misses       5785     5786       +1     
  Partials     1579     1579              
| Flag | Coverage Δ |
| --- | --- |
| keras_core | 76.46% <98.94%> (+0.06%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ |
| --- | --- |
| keras_core/ops/nn.py | 89.69% <93.33%> (+0.13%) ⬆️ |
| keras_core/backend/jax/nn.py | 94.00% <100.00%> (+0.59%) ⬆️ |
| keras_core/backend/numpy/nn.py | 93.10% <100.00%> (+0.71%) ⬆️ |
| keras_core/backend/tensorflow/nn.py | 82.10% <100.00%> (+1.20%) ⬆️ |
| keras_core/backend/torch/nn.py | 91.77% <100.00%> (+0.54%) ⬆️ |
| ...s_core/layers/normalization/batch_normalization.py | 100.00% <100.00%> (ø) |
| ...s_core/layers/normalization/group_normalization.py | 89.01% <100.00%> (-0.12%) ⬇️ |
| ...s_core/layers/normalization/layer_normalization.py | 97.40% <100.00%> (+0.03%) ⬆️ |


@fchollet (Contributor) commented:

Thanks for the PR, this is great! What performance changes did you observe in e.g. BatchNormalization when using moments in different backends compared to the existing manual implementation?

@james77777778 (Contributor, Author) commented on Sep 13, 2023:

> Thanks for the PR, this is great! What performance changes did you observe in e.g. BatchNormalization when using moments in different backends compared to the existing manual implementation?

EDITED:
Please see the newest comment below and let me know which implementation we should adopt.

ORIGINAL:
@fchollet

I have reordered the operations in ops.nn.moments to get a significant speed-up. The key is to reduce (downsize) the tensor first, so that subsequent element-wise operations (such as jnp.subtract) run on far fewer elements.

```python
# actually, tf.nn.moments could be faster...

# original
variance = jnp.mean(
    jnp.square(x - jax.lax.stop_gradient(mean)), axis=axes, keepdims=True
)

# faster version
variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
    jax.lax.stop_gradient(mean)
)
```

I have also updated the normalization layers. I observed better performance with TF and JAX, while torch showed similar performance.

Benchmark script:

```python
from keras_core import layers
from keras_core import models
from keras_core import ops

x_train = ops.random.uniform(shape=(1024, 224, 224, 3))
y_train = ops.random.uniform(shape=(1024, 224, 224, 3))

# layers.BatchNormalization
# layers.GroupNormalization
# layers.LayerNormalization
normalization_cls = layers.BatchNormalization
normalization_args = {}
if normalization_cls is layers.GroupNormalization:
    normalization_args = {"groups": 3}

model = models.Sequential(
    [
        layers.InputLayer(shape=(224, 224, 3)),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
    ]
)
model.compile(loss="mse", optimizer="adam")
model.fit(x_train, y_train, batch_size=128, epochs=3)
```
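
(Not shown in the thread: the script above is presumably run once per backend. In keras_core the backend is selected with the KERAS_BACKEND environment variable before the first import; the snippet below is an assumed harness detail, not part of the PR.)

```python
# Assumed benchmark setup: choose the backend before importing keras_core.
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", "torch"

import keras_core  # noqa: E402  (import must come after setting the variable)
```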

Results (with a 1080 8GB card):

| backend | layer | current implementation | using ops.nn.moments | notes |
| --- | --- | --- | --- | --- |
| tensorflow | BatchNormalization | 58ms | 58ms | fair |
| jax | BatchNormalization | 63ms | 63ms | fair |
| torch | BatchNormalization | 73ms | 72ms | fair |
| tensorflow | GroupNormalization | 89ms | 61ms | ⬇️ |
| jax | GroupNormalization | 96ms | 72ms | ⬇️ |
| torch | GroupNormalization | 72ms | 74ms | fair |
| tensorflow | LayerNormalization | 52ms | 47ms | ⬇️ |
| jax | LayerNormalization | 68ms | 59ms | ⬇️ |
| torch | LayerNormalization | 88ms | 90ms | fair |

@james77777778 changed the title from "Add ops.nn.moments" to "Add ops.nn.moments and speed-up normalization layers" on Sep 13, 2023
@james77777778 (Contributor, Author) commented on Sep 13, 2023:

Hi @fchollet
I have a question about the choice of implementation:

  • $Var = E[|x - E[x]|^2]$
  • $Var = E[|x|^2] - |E[x]|^2$

In TensorFlow

tf.nn.moments uses a slower but more numerically stable version to compute variance

https://github.com/tensorflow/tensorflow/blob/3d1802023778a164d35c79536990b35b701e8018/tensorflow/python/ops/nn_impl.py#L1264C5-L1268C25

```python
variance = math_ops.reduce_mean(
    math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
    axes,
    keepdims=True,
    name="variance")
```

In Keras Core and Flax

Keras Core and Flax (the NN library for JAX) use a faster but less stable version for computing variance

https://github.com/google/flax/blob/ca3ea06f78834137dfb49dc6c1a0c26fb962003a/flax/linen/normalization.py#L108-L120

```python
# use_fast_variance=True by default in Flax
if use_fast_variance:
  mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
  # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
  # to floating point round-off errors.
  var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
else:
  mu = maybe_distributed_mean(x)
  var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
```

Torch

I just checked the CPU version; torch appears to use the same approach as TensorFlow.

https://github.com/pytorch/pytorch/blob/48e6ffbe308e915b67c5b4f9532f794d6706c903/aten/src/ATen/native/cpu/batch_norm_kernel.cpp#L200-L209

Performance Comparison (TensorFlow)

| backend | layer | current implementation | fast variance computation | tf.nn.moments |
| --- | --- | --- | --- | --- |
| tensorflow | BatchNormalization | 58ms | 58ms | 69ms |
| tensorflow | GroupNormalization | 89ms | 61ms | 73ms |
| tensorflow | LayerNormalization | 52ms | 47ms | 48ms |

Which one should we take?

@sachinprasadhs added the "stat:awaiting keras-eng" (Awaiting response from Keras engineer) label on Sep 13, 2023
@fchollet (Contributor) commented:

Thanks for the analysis -- what does "current implementation" refer to? The implementation in this PR?

@james77777778 (Contributor, Author) commented on Sep 14, 2023:

> Thanks for the analysis -- what does "current implementation" refer to? The implementation in this PR?

Sorry for the confusion. I have updated the table as follows:

  • $Var = E[|x - E[x]|^2]$ (stable but slower)
  • $Var = E[|x|^2] - |E[x]|^2$ (fast but unstable)
| backend | layer | manual implementation (before this PR) | ops.nn.moments (fast but unstable variance computation, this PR's implementation) | ops.nn.moments (stable but slower variance computation) |
| --- | --- | --- | --- | --- |
| tensorflow | BatchNormalization | 58ms | 58ms | 69ms |
| jax | BatchNormalization | 63ms | 63ms | 75ms |
| torch | BatchNormalization | 73ms | 72ms | 74ms |
| tensorflow | GroupNormalization | 89ms | 61ms | 73ms |
| jax | GroupNormalization | 96ms | 72ms | 83ms |
| torch | GroupNormalization | 72ms | 74ms | 76ms |
| tensorflow | LayerNormalization | 52ms | 47ms | 48ms |
| jax | LayerNormalization | 68ms | 59ms | 60ms |
| torch | LayerNormalization | 88ms | 90ms | 91ms |

References:

  • tensorflow: tf.nn.moments uses the stable but slower version (link)
  • flax: defaults to use_fast_variance=True (link)
  • torch: appears to use the stable but slower version (link)

The question remains: Should we adopt the fast but unstable variance computation or the stable but slower version?

@fchollet (Contributor) commented:

My take is that until we see reports of users running into stability issues, the fast implementation should be fine. The fact that Flax defaults to it is evidence that there's little issue. I did a quick search within the google codebase and found only a couple of usages of use_fast_variance=False, among thousands of usages of Flax normalization layers (they aren't commented, so it's unclear why they went with False). So it seems that in practice the problem doesn't really surface.

@fchollet (Contributor) commented:

The code looks good! Do you want to include the normalization layer changes in this PR, or merge this PR first and then create another one?

@james77777778 (Contributor, Author) commented:

> My take is that until we see reports of users running into stability issues, the fast implementation should be fine. The fact that Flax defaults to it is evidence that there's little issue. I did a quick search within the google codebase and found only a couple of usages of use_fast_variance=False, among thousands of usages of Flax normalization layers (they aren't commented, so it's unclear why they went with False). So it seems that in practice the problem doesn't really surface.

Thanks for the valuable insights about Flax usage.
After a lot of searching, I could only find this comment defending the numerically stable computation of variance (without evidence?): tensorflow/tensorflow#4198 (comment)

> The code looks good! Do you want to include the normalization layer changes in this PR, or merge this PR first and then create another one?

I think the changes are already in this PR. Please let me know if I missed anything.

@james77777778 (Contributor, Author) commented:

Hi @fchollet
This PR should be ready. We now have fast mean & variance computation using ops.nn.moments, and it is applied to BatchNormalization, GroupNormalization and LayerNormalization to achieve some speed-ups.

Additionally, this PR addresses the overflow and underflow issues that occur when the input is float16. (I encountered this before when using GroupNormalization with mixed_float16.)
Credits to @fsx950223 (tensorflow/tensorflow#52217)
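
For context, a common remedy (and, as far as I can tell, the spirit of the fix referenced above; the helper below is hypothetical and not the PR's code) is to compute the moments in float32 when the input is float16 and cast the results back:

```python
# Hypothetical helper, illustrative only: compute moments in float32 for
# float16 inputs to avoid overflow/underflow in E[x^2], then cast back.
import numpy as np

def moments_fp16_safe(x, axes):
    needs_cast = x.dtype == np.float16
    if needs_cast:
        x = x.astype(np.float32)
    mean = np.mean(x, axis=axes, keepdims=True)
    variance = np.maximum(
        np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean), 0.0
    )
    if needs_cast:
        mean, variance = mean.astype(np.float16), variance.astype(np.float16)
    return mean, variance
```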

@fchollet (Contributor) left a review comment:


LGTM, thank you for the great contribution!

@fchollet merged commit e8db3b6 into keras-team:main on Sep 14, 2023
@james77777778 deleted the add-moments branch on Sep 15, 2023 at 00:55