From a45b7aad39a080f9b155549f060c8f03ec35c102 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Tue, 12 Sep 2023 02:24:44 +0000 Subject: [PATCH 1/4] Add `ops.nn.moments` --- keras_core/backend/jax/nn.py | 33 ++++++++++++++++ keras_core/backend/numpy/nn.py | 27 +++++++++++++ keras_core/backend/tensorflow/nn.py | 4 ++ keras_core/backend/torch/nn.py | 34 +++++++++++++++++ keras_core/ops/nn.py | 59 +++++++++++++++++++++++++++++ keras_core/ops/nn_test.py | 57 ++++++++++++++++++++++++++++ 6 files changed, 214 insertions(+) diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index c0e381d34..177ff007d 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -5,6 +5,7 @@ from jax import nn as jnn from keras_core.backend import standardize_data_format +from keras_core.backend import standardize_dtype from keras_core.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_jax, ) @@ -486,3 +487,35 @@ def binary_crossentropy(target, output, from_logits=False): bce = target * jnp.log(output) bce += (1.0 - target) * jnp.log(1.0 - output) return -bce + + +def moments(x, axes, keepdims=False): + # The dynamic range of fp16 is too limited to support the collection of + # sufficient statistics. As a workaround we simply perform the operations + # on 32-bit floats before converting the mean and variance back to fp16 + need_cast = False + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + # Compute true mean while keeping the dims for proper broadcasting + mean = jnp.mean(x, axes, keepdims=True) + + # Sample variance, not unbiased variance + # Note: stop_gradient does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero + variance = jnp.mean( + jnp.square(x - jax.lax.stop_gradient(mean)), + axis=axes, + keepdims=True, + ) + + if not keepdims: + mean = jnp.squeeze(mean, axes) + variance = jnp.squeeze(variance, axes) + if need_cast: + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance diff --git a/keras_core/backend/numpy/nn.py b/keras_core/backend/numpy/nn.py index b62521347..5aa2210f5 100644 --- a/keras_core/backend/numpy/nn.py +++ b/keras_core/backend/numpy/nn.py @@ -4,6 +4,7 @@ from jax import numpy as jnp from keras_core.backend import standardize_data_format +from keras_core.backend import standardize_dtype from keras_core.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_jax, ) @@ -519,3 +520,29 @@ def binary_crossentropy(target, output, from_logits=False): bce = target * np.log(output) bce += (1.0 - target) * np.log(1.0 - output) return -bce + + +def moments(x, axes, keepdims=False): + axes = tuple(axes) if isinstance(axes, list) else axes + # The dynamic range of fp16 is too limited to support the collection of + # sufficient statistics. 
As a workaround we simply perform the operations + # on 32-bit floats before converting the mean and variance back to fp16 + need_cast = False + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + # Compute true mean while keeping the dims for proper broadcasting + mean = np.mean(x, axes, keepdims=True) + + # Sample variance, not unbiased variance + variance = np.mean(np.square(x - mean), axis=axes, keepdims=True) + + if not keepdims: + mean = np.squeeze(mean, axes) + variance = np.squeeze(variance, axes) + if need_cast: + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py index e0d5454df..5b2912bd1 100644 --- a/keras_core/backend/tensorflow/nn.py +++ b/keras_core/backend/tensorflow/nn.py @@ -646,3 +646,7 @@ def binary_crossentropy(target, output, from_logits=False): bce = target * tf.math.log(output) bce += (1 - target) * tf.math.log(1 - output) return -bce + + +def moments(x, axes, keepdims=False): + return tf.nn.moments(x, axes, keepdims=keepdims) diff --git a/keras_core/backend/torch/nn.py b/keras_core/backend/torch/nn.py index cb7cff213..a79288489 100644 --- a/keras_core/backend/torch/nn.py +++ b/keras_core/backend/torch/nn.py @@ -3,6 +3,7 @@ import torch.nn.functional as tnn from keras_core.backend import standardize_data_format +from keras_core.backend import standardize_dtype from keras_core.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_torch, ) @@ -609,3 +610,36 @@ def binary_crossentropy(target, output, from_logits=False): else: output = torch.clip(output, epsilon(), 1.0 - epsilon()) return tnn.binary_cross_entropy(output, target, reduction="none") + + +def moments(x, axes, keepdims=False): + x = convert_to_tensor(x) + # The dynamic range of fp16 is too limited to support the collection of + # sufficient statistics. 
As a workaround we simply perform the operations + # on 32-bit floats before converting the mean and variance back to fp16 + need_cast = False + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + # Compute true mean while keeping the dims for proper broadcasting + mean = torch.mean(x, dim=axes, keepdim=True) + + # Sample variance, not unbiased variance + # Note: detach does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero + variance = torch.mean( + torch.square(x - mean.detach()), + dim=axes, + keepdim=True, + ) + + if not keepdims: + mean = torch.squeeze(mean, axes) + variance = torch.squeeze(variance, axes) + if need_cast: + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance diff --git a/keras_core/ops/nn.py b/keras_core/ops/nn.py index 9b2f74e08..218bd4c09 100644 --- a/keras_core/ops/nn.py +++ b/keras_core/ops/nn.py @@ -10,6 +10,7 @@ ) from keras_core.ops import operation_utils from keras_core.ops.operation import Operation +from keras_core.ops.operation_utils import reduce_shape class Relu(Operation): @@ -1563,3 +1564,61 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None): return MultiHot(num_tokens, axis, dtype).symbolic_call(inputs) return backend.nn.multi_hot(inputs, num_tokens, axis, dtype) + + +class Moments(Operation): + def __init__(self, axes, keepdims=False, name=None): + super().__init__(name) + self.axes = axes + self.keepdims = keepdims + + def call(self, x): + return backend.nn.moments(x, axes=self.axes, keepdims=self.keepdims) + + def compute_output_spec(self, x): + return ( + KerasTensor( + reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims), + dtype=x.dtype, + ), + KerasTensor( + reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims), + dtype=x.dtype, + ), + ) + + +@keras_core_export( + [ + "keras_core.ops.moments", + "keras_core.ops.nn.moments", + ] +) +def moments(x, axes, keepdims=False): + """Calculates the mean and variance of `x`. + + The mean and variance are calculated by aggregating the contents of `x` + across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and + variance of a vector. + + Args: + x: Input tensor. + axes: A list of axes which to compute mean and variance. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + + Returns: + A tuple containing two tensors - mean and variance. 
+ + + Example: + + >>> x = keras_core.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32") + >>> keras_core.ops.moments(x, axes=[0]) + (array(21.2, dtype=float32), array(1553.3601, dtype=float32)) + + """ + if any_symbolic_tensors((x,)): + return Moments(axes, keepdims).symbolic_call(x) + + return backend.nn.moments(x, axes, keepdims) diff --git a/keras_core/ops/nn_test.py b/keras_core/ops/nn_test.py index f15412d59..efc373631 100644 --- a/keras_core/ops/nn_test.py +++ b/keras_core/ops/nn_test.py @@ -298,6 +298,20 @@ def test_one_hot_dtype(self, dtype): out = knn.one_hot(x, 5, axis=0, dtype=dtype) self.assertEqual(backend.standardize_dtype(out.dtype), dtype) + def test_moments(self): + x = KerasTensor([None, 3, 4]) + self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4)) + self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,)) + self.assertEqual( + knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4) + ) + + self.assertEqual(knn.moments(x, axes=[1])[0].shape, (None, 4)) + self.assertEqual(knn.moments(x, axes=[1, 2])[0].shape, (None,)) + self.assertEqual( + knn.moments(x, axes=[1, 2], keepdims=True)[0].shape, (None, 1, 1) + ) + class NNOpsStaticShapeTest(testing.TestCase): def test_relu(self): @@ -591,6 +605,14 @@ def test_sparse_categorical_crossentropy(self): knn.sparse_categorical_crossentropy(x1, x2).shape, (2, 3) ) + def test_moments(self): + x = KerasTensor([2, 3, 4]) + self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4)) + self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,)) + self.assertEqual( + knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4) + ) + class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): def test_relu(self): @@ -1156,3 +1178,38 @@ def test_multi_hot(self): indices_1d = np.array([0, -1, -1, 3]) expected_output_1d = np.array([1, 0, 0, 1]) self.assertAllClose(knn.multi_hot(indices_1d, 4), expected_output_1d) + + def test_moments(self): + # Test 1D moments + x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32) + mean, variance = knn.moments(x, axes=[0]) + self.assertAllClose(mean, np.mean(x)) + self.assertAllClose(variance, np.var(x)) + + # Test batch statics for 4D moments (batch, height, width, channels) + x = np.random.uniform(size=(2, 28, 28, 3)) + mean, variance = knn.moments(x, axes=[0]) + self.assertAllClose(mean, np.mean(x, axis=0)) + self.assertAllClose(variance, np.var(x, axis=0)) + + # Test global statics for 4D moments (batch, height, width, channels) + x = np.random.uniform(size=(2, 28, 28, 3)) + mean, variance = knn.moments(x, axes=[0, 1, 2]) + self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2))) + self.assertAllClose(variance, np.var(x, axis=(0, 1, 2))) + + # Test keepdims + x = np.random.uniform(size=(2, 28, 28, 3)) + mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True) + self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2), keepdims=True)) + self.assertAllClose(variance, np.var(x, axis=(0, 1, 2), keepdims=True)) + + # Test float16 + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float16) + mean, variance = knn.moments(x, axes=[0]) + expected_mean = np.mean(x.astype(np.float32), axis=0).astype(np.float16) + expected_variance = np.var(x.astype(np.float32), axis=0).astype( + np.float16 + ) + self.assertAllClose(mean, expected_mean) + self.assertAllClose(variance, expected_variance) From 17d2fbc137a1aa4ec971ca4e8293df384454d776 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Tue, 12 Sep 2023 03:02:33 +0000 Subject: [PATCH 
2/4] Delete empty lines --- keras_core/ops/nn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_core/ops/nn.py b/keras_core/ops/nn.py index 218bd4c09..05dbba4fa 100644 --- a/keras_core/ops/nn.py +++ b/keras_core/ops/nn.py @@ -1552,7 +1552,6 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None): Returns: Tensor: The multi-hot encoded tensor. - Example: >>> data = keras_core.ops.convert_to_tensor([0, 4]) @@ -1610,7 +1609,6 @@ def moments(x, axes, keepdims=False): Returns: A tuple containing two tensors - mean and variance. - Example: >>> x = keras_core.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32") From 99d66577d6510ef6861028bbb848c5f42d5da613 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Wed, 13 Sep 2023 01:31:25 +0000 Subject: [PATCH 3/4] Improve performance --- keras_core/backend/jax/nn.py | 6 ++-- keras_core/backend/numpy/nn.py | 2 +- keras_core/backend/tensorflow/nn.py | 29 ++++++++++++++++++- keras_core/backend/torch/nn.py | 6 ++-- .../normalization/batch_normalization.py | 7 ++--- .../normalization/group_normalization.py | 7 ++--- .../normalization/layer_normalization.py | 14 ++++----- 7 files changed, 44 insertions(+), 27 deletions(-) diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index 177ff007d..7832f2550 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -506,10 +506,8 @@ def moments(x, axes, keepdims=False): # Note: stop_gradient does not change the gradient that gets # backpropagated to the mean from the variance calculation, # because that gradient is zero - variance = jnp.mean( - jnp.square(x - jax.lax.stop_gradient(mean)), - axis=axes, - keepdims=True, + variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square( + jax.lax.stop_gradient(mean) ) if not keepdims: diff --git a/keras_core/backend/numpy/nn.py b/keras_core/backend/numpy/nn.py index 5aa2210f5..abc3a092f 100644 --- a/keras_core/backend/numpy/nn.py +++ b/keras_core/backend/numpy/nn.py @@ -537,7 +537,7 @@ def moments(x, axes, keepdims=False): mean = np.mean(x, axes, keepdims=True) # Sample variance, not unbiased variance - variance = np.mean(np.square(x - mean), axis=axes, keepdims=True) + variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean) if not keepdims: mean = np.squeeze(mean, axes) diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py index 5b2912bd1..20f888910 100644 --- a/keras_core/backend/tensorflow/nn.py +++ b/keras_core/backend/tensorflow/nn.py @@ -3,6 +3,7 @@ import tensorflow as tf from keras_core.backend import standardize_data_format +from keras_core.backend import standardize_dtype from keras_core.backend.common.backend_utils import ( compute_conv_transpose_output_shape, ) @@ -649,4 +650,30 @@ def binary_crossentropy(target, output, from_logits=False): def moments(x, axes, keepdims=False): - return tf.nn.moments(x, axes, keepdims=keepdims) + # The dynamic range of fp16 is too limited to support the collection of + # sufficient statistics. 
As a workaround we simply perform the operations + # on 32-bit floats before converting the mean and variance back to fp16 + need_cast = False + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + # Compute true mean while keeping the dims for proper broadcasting + mean = tf.reduce_mean(x, axes, keepdims=True) + + # Sample variance, not unbiased variance + # Note: stop_gradient does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero + variance = tf.reduce_mean( + tf.square(x), axis=axes, keepdims=True + ) - tf.square(tf.stop_gradient(mean)) + + if not keepdims: + mean = tf.squeeze(mean, axes) + variance = tf.squeeze(variance, axes) + if need_cast: + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance diff --git a/keras_core/backend/torch/nn.py b/keras_core/backend/torch/nn.py index a79288489..cf288245b 100644 --- a/keras_core/backend/torch/nn.py +++ b/keras_core/backend/torch/nn.py @@ -631,10 +631,8 @@ def moments(x, axes, keepdims=False): # backpropagated to the mean from the variance calculation, # because that gradient is zero variance = torch.mean( - torch.square(x - mean.detach()), - dim=axes, - keepdim=True, - ) + torch.square(x), dim=axes, keepdim=True + ) - torch.square(mean.detach()) if not keepdims: mean = torch.squeeze(mean, axes) diff --git a/keras_core/layers/normalization/batch_normalization.py b/keras_core/layers/normalization/batch_normalization.py index a447313c2..0812f6843 100644 --- a/keras_core/layers/normalization/batch_normalization.py +++ b/keras_core/layers/normalization/batch_normalization.py @@ -198,10 +198,9 @@ def call(self, inputs, training=None, mask=None): broadcast_shape = [1] * len(inputs.shape) broadcast_shape[self.axis] = inputs.shape[self.axis] if training and self.trainable: - mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True) - variance = ops.mean( - ops.square(inputs), axis=self._reduction_axes, keepdims=True - ) - ops.square(mean) + mean, variance = ops.moments( + inputs, axes=self._reduction_axes, keepdims=True + ) outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon) mean = ops.squeeze(mean, self._reduction_axes) variance = ops.squeeze(variance, self._reduction_axes) diff --git a/keras_core/layers/normalization/group_normalization.py b/keras_core/layers/normalization/group_normalization.py index 1fed1abc5..94b56b05f 100644 --- a/keras_core/layers/normalization/group_normalization.py +++ b/keras_core/layers/normalization/group_normalization.py @@ -171,11 +171,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape): axis = -2 if self.axis == -1 else self.axis - 1 group_reduction_axes.pop(axis) - mean = ops.mean( - reshaped_inputs, axis=group_reduction_axes, keepdims=True - ) - variance = ops.var( - reshaped_inputs, axis=group_reduction_axes, keepdims=True + mean, variance = ops.moments( + reshaped_inputs, axes=group_reduction_axes, keepdims=True ) gamma, beta = self._get_reshaped_weights(input_shape) diff --git a/keras_core/layers/normalization/layer_normalization.py b/keras_core/layers/normalization/layer_normalization.py index fe0bc46df..7d8381c27 100644 --- a/keras_core/layers/normalization/layer_normalization.py +++ b/keras_core/layers/normalization/layer_normalization.py @@ -203,19 +203,17 @@ def _broadcast(v): # this is at least as numerically stable as the fused version. 
inputs = ops.cast(inputs, "float32") - # Calculate the variance last axis (layer activations). - variance = ops.var(inputs, axis=self.axis, keepdims=True) - - # Compute the batch normalization. - inv = 1 / ops.sqrt(variance + self.epsilon) - if self.rms_scaling: # Calculate outputs with only variance and gamma if rms scaling # is enabled + # Calculate the variance along last axis (layer activations). + variance = ops.var(inputs, axis=self.axis, keepdims=True) + inv = 1 / ops.sqrt(variance + self.epsilon) outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma else: - # Calculate the mean last axis (layer activations). - mean = ops.mean(inputs, axis=self.axis, keepdims=True) + # Calculate the mean & variance along last axis (layer activations). + mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True) + inv = 1 / ops.sqrt(variance + self.epsilon) scale, offset = _broadcast(self.gamma), _broadcast(self.beta) if scale is not None: scale = ops.cast(scale, inputs.dtype) From 8f06862bd09ff5439da4485cb8d46ebd425cadc7 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Thu, 14 Sep 2023 07:45:01 +0000 Subject: [PATCH 4/4] Address the issue of overflow and underflow when casting --- keras_core/backend/jax/nn.py | 22 ++++++++++++++-------- keras_core/backend/numpy/nn.py | 15 ++++++++++----- keras_core/backend/tensorflow/nn.py | 18 ++++++++++-------- keras_core/backend/torch/nn.py | 26 ++++++++++++++++++-------- keras_core/ops/nn_test.py | 19 +++++++++++-------- 5 files changed, 63 insertions(+), 37 deletions(-) diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index 7832f2550..7fff15ada 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -490,22 +490,21 @@ def binary_crossentropy(target, output, from_logits=False): def moments(x, axes, keepdims=False): - # The dynamic range of fp16 is too limited to support the collection of - # sufficient statistics. As a workaround we simply perform the operations - # on 32-bit floats before converting the mean and variance back to fp16 + # The dynamic range of float16 is too limited for statistics. As a + # workaround, we simply perform the operations on float32 and convert back + # to float16 need_cast = False ori_dtype = standardize_dtype(x.dtype) if ori_dtype == "float16": need_cast = True x = cast(x, "float32") - # Compute true mean while keeping the dims for proper broadcasting mean = jnp.mean(x, axes, keepdims=True) - # Sample variance, not unbiased variance - # Note: stop_gradient does not change the gradient that gets - # backpropagated to the mean from the variance calculation, - # because that gradient is zero + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + # Note: stop_gradient does not change the gradient to the mean, because that + # gradient is zero. 
     variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
         jax.lax.stop_gradient(mean)
     )
@@ -514,6 +513,13 @@ def moments(x, axes, keepdims=False):
     if not keepdims:
         mean = jnp.squeeze(mean, axes)
         variance = jnp.squeeze(variance, axes)
     if need_cast:
+        # avoid overflow and underflow when casting back to float16
+        mean = jnp.clip(
+            mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
+        )
+        variance = jnp.clip(
+            variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
+        )
         mean = cast(mean, ori_dtype)
         variance = cast(variance, ori_dtype)
     return mean, variance
diff --git a/keras_core/backend/numpy/nn.py b/keras_core/backend/numpy/nn.py
index abc3a092f..231575874 100644
--- a/keras_core/backend/numpy/nn.py
+++ b/keras_core/backend/numpy/nn.py
@@ -524,25 +524,30 @@ def binary_crossentropy(target, output, from_logits=False):
 
 def moments(x, axes, keepdims=False):
     axes = tuple(axes) if isinstance(axes, list) else axes
-    # The dynamic range of fp16 is too limited to support the collection of
-    # sufficient statistics. As a workaround we simply perform the operations
-    # on 32-bit floats before converting the mean and variance back to fp16
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16
     need_cast = False
     ori_dtype = standardize_dtype(x.dtype)
     if ori_dtype == "float16":
         need_cast = True
         x = cast(x, "float32")
 
-    # Compute true mean while keeping the dims for proper broadcasting
     mean = np.mean(x, axes, keepdims=True)
 
-    # Sample variance, not unbiased variance
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)
 
     if not keepdims:
         mean = np.squeeze(mean, axes)
         variance = np.squeeze(variance, axes)
     if need_cast:
+        # avoid overflow and underflow when casting back to float16
+        mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max)
+        variance = np.clip(
+            variance, np.finfo(np.float16).min, np.finfo(np.float16).max
+        )
         mean = cast(mean, ori_dtype)
         variance = cast(variance, ori_dtype)
     return mean, variance
diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py
index 20f888910..e3eb8f6ea 100644
--- a/keras_core/backend/tensorflow/nn.py
+++ b/keras_core/backend/tensorflow/nn.py
@@ -650,22 +650,21 @@ def binary_crossentropy(target, output, from_logits=False):
 
 def moments(x, axes, keepdims=False):
-    # The dynamic range of fp16 is too limited to support the collection of
-    # sufficient statistics. As a workaround we simply perform the operations
-    # on 32-bit floats before converting the mean and variance back to fp16
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16
     need_cast = False
     ori_dtype = standardize_dtype(x.dtype)
     if ori_dtype == "float16":
         need_cast = True
         x = cast(x, "float32")
 
-    # Compute true mean while keeping the dims for proper broadcasting
     mean = tf.reduce_mean(x, axes, keepdims=True)
 
-    # Sample variance, not unbiased variance
-    # Note: stop_gradient does not change the gradient that gets
-    # backpropagated to the mean from the variance calculation,
-    # because that gradient is zero
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
+    # Note: stop_gradient does not change the gradient to the mean, because that
+    # gradient is zero.
     variance = tf.reduce_mean(
         tf.square(x), axis=axes, keepdims=True
     ) - tf.square(tf.stop_gradient(mean))
@@ -674,6 +673,9 @@ def moments(x, axes, keepdims=False):
     if not keepdims:
         mean = tf.squeeze(mean, axes)
         variance = tf.squeeze(variance, axes)
     if need_cast:
+        # avoid overflow and underflow when casting back to float16
+        mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)
+        variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)
         mean = cast(mean, ori_dtype)
         variance = cast(variance, ori_dtype)
     return mean, variance
diff --git a/keras_core/backend/torch/nn.py b/keras_core/backend/torch/nn.py
index cf288245b..295c88914 100644
--- a/keras_core/backend/torch/nn.py
+++ b/keras_core/backend/torch/nn.py
@@ -614,22 +614,21 @@ def binary_crossentropy(target, output, from_logits=False):
 
 def moments(x, axes, keepdims=False):
     x = convert_to_tensor(x)
-    # The dynamic range of fp16 is too limited to support the collection of
-    # sufficient statistics. As a workaround we simply perform the operations
-    # on 32-bit floats before converting the mean and variance back to fp16
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16
     need_cast = False
     ori_dtype = standardize_dtype(x.dtype)
     if ori_dtype == "float16":
         need_cast = True
         x = cast(x, "float32")
 
-    # Compute true mean while keeping the dims for proper broadcasting
     mean = torch.mean(x, dim=axes, keepdim=True)
 
-    # Sample variance, not unbiased variance
-    # Note: detach does not change the gradient that gets
-    # backpropagated to the mean from the variance calculation,
-    # because that gradient is zero
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
+    # Note: detach does not change the gradient to the mean, because that
+    # gradient is zero.
     variance = torch.mean(
         torch.square(x), dim=axes, keepdim=True
     ) - torch.square(mean.detach())
@@ -638,6 +637,17 @@ def moments(x, axes, keepdims=False):
     if not keepdims:
         mean = torch.squeeze(mean, axes)
         variance = torch.squeeze(variance, axes)
     if need_cast:
+        # avoid overflow and underflow when casting back to float16
+        mean = torch.clip(
+            mean,
+            torch.finfo(torch.float16).min,
+            torch.finfo(torch.float16).max,
+        )
+        variance = torch.clip(
+            variance,
+            torch.finfo(torch.float16).min,
+            torch.finfo(torch.float16).max,
+        )
         mean = cast(mean, ori_dtype)
         variance = cast(variance, ori_dtype)
     return mean, variance
diff --git a/keras_core/ops/nn_test.py b/keras_core/ops/nn_test.py
index efc373631..99233f41d 100644
--- a/keras_core/ops/nn_test.py
+++ b/keras_core/ops/nn_test.py
@@ -1186,13 +1186,13 @@ def test_moments(self):
         self.assertAllClose(mean, np.mean(x))
         self.assertAllClose(variance, np.var(x))
 
-        # Test batch statics for 4D moments (batch, height, width, channels)
+        # Test batch statistics for 4D moments (batch, height, width, channels)
         x = np.random.uniform(size=(2, 28, 28, 3))
         mean, variance = knn.moments(x, axes=[0])
         self.assertAllClose(mean, np.mean(x, axis=0))
         self.assertAllClose(variance, np.var(x, axis=0))
 
-        # Test global statics for 4D moments (batch, height, width, channels)
+        # Test global statistics for 4D moments (batch, height, width, channels)
         x = np.random.uniform(size=(2, 28, 28, 3))
         mean, variance = knn.moments(x, axes=[0, 1, 2])
         self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2)))
         self.assertAllClose(variance, np.var(x, axis=(0, 1, 2)))
@@ -1204,12 +1204,15 @@ def test_moments(self):
         self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2), keepdims=True))
         self.assertAllClose(variance, np.var(x, axis=(0, 1, 2), keepdims=True))
 
-        # Test float16
-        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float16)
-        mean, variance = knn.moments(x, axes=[0])
-        expected_mean = np.mean(x.astype(np.float32), axis=0).astype(np.float16)
-        expected_variance = np.var(x.astype(np.float32), axis=0).astype(
-            np.float16
+        # Test float16, where the variance overflows the float16 range
+        x = np.array(
+            [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
+            dtype=np.float16,
         )
+        mean, variance = knn.moments(x, axes=[0])
+        expected_mean = np.mean(x.astype(np.float32)).astype(np.float16)
+        # the output variance is clipped to the max value of np.float16
+        # because it overflows
+        expected_variance = np.finfo(np.float16).max
         self.assertAllClose(mean, expected_mean)
         self.assertAllClose(variance, expected_variance)
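
For reference, the float16 handling that PATCH 4 adds to every backend follows the same recipe: compute the statistics in float32 using Var = E[x^2] - E[x]^2, clip the results to the float16 range, and only then cast back, so the cast saturates instead of producing inf. Below is a minimal NumPy sketch of that recipe; it is illustrative only, not keras_core code, and the helper name naive_moments_float16_safe is hypothetical.

import numpy as np

def naive_moments_float16_safe(x, axes, keepdims=False):
    # Widen to float32 before reducing; float16 statistics overflow easily.
    x32 = x.astype(np.float32)
    mean = np.mean(x32, axis=axes, keepdims=keepdims)
    # Var = E[x^2] - E[x]^2: faster than E[(x - E[x])^2] but less
    # numerically stable.
    variance = (
        np.mean(np.square(x32), axis=axes, keepdims=keepdims) - np.square(mean)
    )
    # Clip so the cast back to float16 saturates instead of producing inf.
    f16 = np.finfo(np.float16)
    mean = np.clip(mean, f16.min, f16.max).astype(x.dtype)
    variance = np.clip(variance, f16.min, f16.max).astype(x.dtype)
    return mean, variance

# Same data as the new float16 test in nn_test.py: the true variance
# (roughly 6.9e5) exceeds the float16 maximum (65504), so it gets clipped.
x = np.array(
    [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
    dtype=np.float16,
)
mean, variance = naive_moments_float16_safe(x, axes=0)
print(mean)      # finite, close to np.mean(x.astype(np.float32))
print(variance)  # equals np.finfo(np.float16).max, i.e. saturated, not inf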