diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py
index c0e381d34..7fff15ada 100644
--- a/keras_core/backend/jax/nn.py
+++ b/keras_core/backend/jax/nn.py
@@ -5,6 +5,7 @@ from jax import nn as jnn
 from keras_core.backend import standardize_data_format
+from keras_core.backend import standardize_dtype
 from keras_core.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -486,3 +487,39 @@ def binary_crossentropy(target, output, from_logits=False):
     bce = target * jnp.log(output)
     bce += (1.0 - target) * jnp.log(1.0 - output)
     return -bce
+
+
+def moments(x, axes, keepdims=False):
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16.
+    need_cast = False
+    ori_dtype = standardize_dtype(x.dtype)
+    if ori_dtype == "float16":
+        need_cast = True
+        x = cast(x, "float32")
+
+    mean = jnp.mean(x, axes, keepdims=True)
+
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
+    # Note: stop_gradient does not change the gradient to the mean, because that
+    # gradient is zero.
+    variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
+        jax.lax.stop_gradient(mean)
+    )
+
+    if not keepdims:
+        mean = jnp.squeeze(mean, axes)
+        variance = jnp.squeeze(variance, axes)
+    if need_cast:
+        # avoid overflow and underflow when casting from float32 to float16
+        mean = jnp.clip(
+            mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
+        )
+        variance = jnp.clip(
+            variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
+        )
+        mean = cast(mean, ori_dtype)
+        variance = cast(variance, ori_dtype)
+    return mean, variance
diff --git a/keras_core/backend/numpy/nn.py b/keras_core/backend/numpy/nn.py
index b62521347..231575874 100644
--- a/keras_core/backend/numpy/nn.py
+++ b/keras_core/backend/numpy/nn.py
@@ -4,6 +4,7 @@ from jax import numpy as jnp
 from keras_core.backend import standardize_data_format
+from keras_core.backend import standardize_dtype
 from keras_core.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -519,3 +520,34 @@ def binary_crossentropy(target, output, from_logits=False):
     bce = target * np.log(output)
     bce += (1.0 - target) * np.log(1.0 - output)
     return -bce
+
+
+def moments(x, axes, keepdims=False):
+    axes = tuple(axes) if isinstance(axes, list) else axes
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16.
+    need_cast = False
+    ori_dtype = standardize_dtype(x.dtype)
+    if ori_dtype == "float16":
+        need_cast = True
+        x = cast(x, "float32")
+
+    mean = np.mean(x, axes, keepdims=True)
+
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
+    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)
+
+    if not keepdims:
+        mean = np.squeeze(mean, axes)
+        variance = np.squeeze(variance, axes)
+    if need_cast:
+        # avoid overflow and underflow when casting from float32 to float16
+        mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max)
+        variance = np.clip(
+            variance, np.finfo(np.float16).min, np.finfo(np.float16).max
+        )
+        mean = cast(mean, ori_dtype)
+        variance = cast(variance, ori_dtype)
+    return mean, variance
diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py
index e0d5454df..e3eb8f6ea 100644
--- a/keras_core/backend/tensorflow/nn.py
+++ b/keras_core/backend/tensorflow/nn.py
@@ -3,6 +3,7 @@ import tensorflow as tf
 from keras_core.backend import standardize_data_format
+from keras_core.backend import standardize_dtype
 from keras_core.backend.common.backend_utils import (
     compute_conv_transpose_output_shape,
 )
@@ -646,3 +647,35 @@ def binary_crossentropy(target, output, from_logits=False):
     bce = target * tf.math.log(output)
     bce += (1 - target) * tf.math.log(1 - output)
     return -bce
+
+
+def moments(x, axes, keepdims=False):
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16.
+    need_cast = False
+    ori_dtype = standardize_dtype(x.dtype)
+    if ori_dtype == "float16":
+        need_cast = True
+        x = cast(x, "float32")
+
+    mean = tf.reduce_mean(x, axes, keepdims=True)
+
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
+    # Note: stop_gradient does not change the gradient to the mean, because that
+    # gradient is zero.
+    variance = tf.reduce_mean(
+        tf.square(x), axis=axes, keepdims=True
+    ) - tf.square(tf.stop_gradient(mean))
+
+    if not keepdims:
+        mean = tf.squeeze(mean, axes)
+        variance = tf.squeeze(variance, axes)
+    if need_cast:
+        # avoid overflow and underflow when casting from float32 to float16
+        mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)
+        variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)
+        mean = cast(mean, ori_dtype)
+        variance = cast(variance, ori_dtype)
+    return mean, variance
diff --git a/keras_core/backend/torch/nn.py b/keras_core/backend/torch/nn.py
index cb7cff213..295c88914 100644
--- a/keras_core/backend/torch/nn.py
+++ b/keras_core/backend/torch/nn.py
@@ -3,6 +3,7 @@ import torch.nn.functional as tnn
 from keras_core.backend import standardize_data_format
+from keras_core.backend import standardize_dtype
 from keras_core.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_torch,
 )
@@ -609,3 +610,44 @@ def binary_crossentropy(target, output, from_logits=False):
     else:
         output = torch.clip(output, epsilon(), 1.0 - epsilon())
     return tnn.binary_cross_entropy(output, target, reduction="none")
+
+
+def moments(x, axes, keepdims=False):
+    x = convert_to_tensor(x)
+    # The dynamic range of float16 is too limited for statistics. As a
+    # workaround, we simply perform the operations on float32 and convert back
+    # to float16.
+    need_cast = False
+    ori_dtype = standardize_dtype(x.dtype)
+    if ori_dtype == "float16":
+        need_cast = True
+        x = cast(x, "float32")
+
+    mean = torch.mean(x, dim=axes, keepdim=True)
+
+    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
+    # but less numerically stable.
+    # Note: detach does not change the gradient to the mean, because that
+    # gradient is zero.
+    variance = torch.mean(
+        torch.square(x), dim=axes, keepdim=True
+    ) - torch.square(mean.detach())
+
+    if not keepdims:
+        mean = torch.squeeze(mean, axes)
+        variance = torch.squeeze(variance, axes)
+    if need_cast:
+        # avoid overflow and underflow when casting from float32 to float16
+        mean = torch.clip(
+            mean,
+            torch.finfo(torch.float16).min,
+            torch.finfo(torch.float16).max,
+        )
+        variance = torch.clip(
+            variance,
+            torch.finfo(torch.float16).min,
+            torch.finfo(torch.float16).max,
+        )
+        mean = cast(mean, ori_dtype)
+        variance = cast(variance, ori_dtype)
+    return mean, variance
diff --git a/keras_core/layers/normalization/batch_normalization.py b/keras_core/layers/normalization/batch_normalization.py
index a447313c2..0812f6843 100644
--- a/keras_core/layers/normalization/batch_normalization.py
+++ b/keras_core/layers/normalization/batch_normalization.py
@@ -198,10 +198,9 @@ def call(self, inputs, training=None, mask=None):
         broadcast_shape = [1] * len(inputs.shape)
         broadcast_shape[self.axis] = inputs.shape[self.axis]
         if training and self.trainable:
-            mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True)
-            variance = ops.mean(
-                ops.square(inputs), axis=self._reduction_axes, keepdims=True
-            ) - ops.square(mean)
+            mean, variance = ops.moments(
+                inputs, axes=self._reduction_axes, keepdims=True
+            )
             outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
             mean = ops.squeeze(mean, self._reduction_axes)
             variance = ops.squeeze(variance, self._reduction_axes)
diff --git a/keras_core/layers/normalization/group_normalization.py b/keras_core/layers/normalization/group_normalization.py
index 1fed1abc5..94b56b05f 100644
--- a/keras_core/layers/normalization/group_normalization.py
+++ b/keras_core/layers/normalization/group_normalization.py
@@ -171,11 +171,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
         axis = -2 if self.axis == -1 else self.axis - 1
         group_reduction_axes.pop(axis)
 
-        mean = ops.mean(
-            reshaped_inputs, axis=group_reduction_axes, keepdims=True
-        )
-        variance = ops.var(
-            reshaped_inputs, axis=group_reduction_axes, keepdims=True
+        mean, variance = ops.moments(
+            reshaped_inputs, axes=group_reduction_axes, keepdims=True
         )
 
         gamma, beta = self._get_reshaped_weights(input_shape)
diff --git a/keras_core/layers/normalization/layer_normalization.py b/keras_core/layers/normalization/layer_normalization.py
index fe0bc46df..7d8381c27 100644
--- a/keras_core/layers/normalization/layer_normalization.py
+++ b/keras_core/layers/normalization/layer_normalization.py
@@ -203,19 +203,17 @@ def _broadcast(v):
             # this is at least as numerically stable as the fused version.
             inputs = ops.cast(inputs, "float32")
 
-        # Calculate the variance last axis (layer activations).
-        variance = ops.var(inputs, axis=self.axis, keepdims=True)
-
-        # Compute the batch normalization.
-        inv = 1 / ops.sqrt(variance + self.epsilon)
-
         if self.rms_scaling:
             # Calculate outputs with only variance and gamma if rms scaling
             # is enabled
+            # Calculate the variance along last axis (layer activations).
+            variance = ops.var(inputs, axis=self.axis, keepdims=True)
+            inv = 1 / ops.sqrt(variance + self.epsilon)
             outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma
         else:
-            # Calculate the mean last axis (layer activations).
-            mean = ops.mean(inputs, axis=self.axis, keepdims=True)
+            # Calculate the mean & variance along last axis (layer activations).
+            mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
+            inv = 1 / ops.sqrt(variance + self.epsilon)
             scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
             if scale is not None:
                 scale = ops.cast(scale, inputs.dtype)
diff --git a/keras_core/ops/nn.py b/keras_core/ops/nn.py
index 9b2f74e08..05dbba4fa 100644
--- a/keras_core/ops/nn.py
+++ b/keras_core/ops/nn.py
@@ -10,6 +10,7 @@
 )
 from keras_core.ops import operation_utils
 from keras_core.ops.operation import Operation
+from keras_core.ops.operation_utils import reduce_shape
 
 
 class Relu(Operation):
@@ -1551,7 +1552,6 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None):
 
     Returns:
         Tensor: The multi-hot encoded tensor.
-
     Example:
 
    >>> data = keras_core.ops.convert_to_tensor([0, 4])
@@ -1563,3 +1563,60 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None):
         return MultiHot(num_tokens, axis, dtype).symbolic_call(inputs)
 
     return backend.nn.multi_hot(inputs, num_tokens, axis, dtype)
+
+
+class Moments(Operation):
+    def __init__(self, axes, keepdims=False, name=None):
+        super().__init__(name)
+        self.axes = axes
+        self.keepdims = keepdims
+
+    def call(self, x):
+        return backend.nn.moments(x, axes=self.axes, keepdims=self.keepdims)
+
+    def compute_output_spec(self, x):
+        return (
+            KerasTensor(
+                reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),
+                dtype=x.dtype,
+            ),
+            KerasTensor(
+                reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),
+                dtype=x.dtype,
+            ),
+        )
+
+
+@keras_core_export(
+    [
+        "keras_core.ops.moments",
+        "keras_core.ops.nn.moments",
+    ]
+)
+def moments(x, axes, keepdims=False):
+    """Calculates the mean and variance of `x`.
+
+    The mean and variance are calculated by aggregating the contents of `x`
+    across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and
+    variance of a vector.
+
+    Args:
+        x: Input tensor.
+        axes: A list of axes along which to compute the mean and variance.
+        keepdims: If this is set to `True`, the axes which are reduced are left
+            in the result as dimensions with size one.
+
+    Returns:
+        A tuple containing two tensors - mean and variance.
+
+    Example:
+
+    >>> x = keras_core.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32")
+    >>> keras_core.ops.moments(x, axes=[0])
+    (array(21.2, dtype=float32), array(1553.3601, dtype=float32))
+
+    """
+    if any_symbolic_tensors((x,)):
+        return Moments(axes, keepdims).symbolic_call(x)
+
+    return backend.nn.moments(x, axes, keepdims)
diff --git a/keras_core/ops/nn_test.py b/keras_core/ops/nn_test.py
index f15412d59..99233f41d 100644
--- a/keras_core/ops/nn_test.py
+++ b/keras_core/ops/nn_test.py
@@ -298,6 +298,20 @@ def test_one_hot_dtype(self, dtype):
         out = knn.one_hot(x, 5, axis=0, dtype=dtype)
         self.assertEqual(backend.standardize_dtype(out.dtype), dtype)
 
+    def test_moments(self):
+        x = KerasTensor([None, 3, 4])
+        self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))
+        self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))
+        self.assertEqual(
+            knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)
+        )
+
+        self.assertEqual(knn.moments(x, axes=[1])[0].shape, (None, 4))
+        self.assertEqual(knn.moments(x, axes=[1, 2])[0].shape, (None,))
+        self.assertEqual(
+            knn.moments(x, axes=[1, 2], keepdims=True)[0].shape, (None, 1, 1)
+        )
+
 
 class NNOpsStaticShapeTest(testing.TestCase):
     def test_relu(self):
@@ -591,6 +605,14 @@ def test_sparse_categorical_crossentropy(self):
             knn.sparse_categorical_crossentropy(x1, x2).shape, (2, 3)
         )
 
+    def test_moments(self):
+        x = KerasTensor([2, 3, 4])
+        self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))
+        self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))
+        self.assertEqual(
+            knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)
+        )
+
 
 class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
     def test_relu(self):
@@ -1156,3 +1178,41 @@ def test_multi_hot(self):
         indices_1d = np.array([0, -1, -1, 3])
         expected_output_1d = np.array([1, 0, 0, 1])
         self.assertAllClose(knn.multi_hot(indices_1d, 4), expected_output_1d)
+
+    def test_moments(self):
+        # Test 1D moments
+        x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32)
+        mean, variance = knn.moments(x, axes=[0])
+        self.assertAllClose(mean, np.mean(x))
+        self.assertAllClose(variance, np.var(x))
+
+        # Test batch statistics for 4D moments (batch, height, width, channels)
+        x = np.random.uniform(size=(2, 28, 28, 3))
+        mean, variance = knn.moments(x, axes=[0])
+        self.assertAllClose(mean, np.mean(x, axis=0))
+        self.assertAllClose(variance, np.var(x, axis=0))
+
+        # Test global statistics for 4D moments (batch, height, width, channels)
+        x = np.random.uniform(size=(2, 28, 28, 3))
+        mean, variance = knn.moments(x, axes=[0, 1, 2])
+        self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2)))
+        self.assertAllClose(variance, np.var(x, axis=(0, 1, 2)))
+
+        # Test keepdims
+        x = np.random.uniform(size=(2, 28, 28, 3))
+        mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True)
+        self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2), keepdims=True))
+        self.assertAllClose(variance, np.var(x, axis=(0, 1, 2), keepdims=True))
+
+        # Test float16 which causes overflow
+        x = np.array(
+            [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
+            dtype=np.float16,
+        )
+        mean, variance = knn.moments(x, axes=[0])
+        expected_mean = np.mean(x.astype(np.float32)).astype(np.float16)
+        # the output variance is clipped to the max value of np.float16
+        # because it overflows
+        expected_variance = np.finfo(np.float16).max
+        self.assertAllClose(mean, expected_mean)
+        self.assertAllClose(variance, expected_variance)
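For reference (not part of the patch): a minimal, standalone NumPy sketch of the shortcut shared by the backend implementations above, where the variance is taken as E[x^2] - E[x]^2 and float16 inputs are upcast to float32 before the results are clipped back into float16 range. The helper name `reference_moments` is invented for illustration only; with the patch applied, `keras_core.ops.moments(x, axes=[0])` should return the same pair of values.

```python
import numpy as np


def reference_moments(x, axes, keepdims=False):
    # Hypothetical helper mirroring the patch: Var = E[x^2] - E[x]^2,
    # computed in float32 when the input is float16, then clipped back
    # into the representable float16 range before downcasting.
    ori_dtype = x.dtype
    if ori_dtype == np.float16:
        x = x.astype(np.float32)
    mean = np.mean(x, axis=tuple(axes), keepdims=keepdims)
    second_moment = np.mean(np.square(x), axis=tuple(axes), keepdims=keepdims)
    variance = second_moment - np.square(mean)
    if ori_dtype == np.float16:
        f16 = np.finfo(np.float16)
        mean = np.clip(mean, f16.min, f16.max).astype(np.float16)
        variance = np.clip(variance, f16.min, f16.max).astype(np.float16)
    return mean, variance


x = np.array([0, 1, 2, 3, 100], dtype="float32")
print(reference_moments(x, axes=[0]))  # approx. (21.2, 1553.36)
print(np.mean(x), np.var(x))  # should match the pair above
```

The same check explains the float16 test case at the end of `nn_test.py`: the true variance of that float16 input exceeds the largest representable float16 value, so after the clip the op is expected to return `np.finfo(np.float16).max`.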