Add ops.nn.moments and speed up normalization layers (#866)
* Add `ops.nn.moments`

* Delete empty lines

* Improve performance

* Address the issue of overflow and underflow when casting
james77777778 authored Sep 14, 2023
1 parent bb21710 commit e8db3b6
Showing 9 changed files with 273 additions and 18 deletions.
37 changes: 37 additions & 0 deletions keras_core/backend/jax/nn.py
@@ -5,6 +5,7 @@
from jax import nn as jnn

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_jax,
)
@@ -486,3 +487,39 @@ def binary_crossentropy(target, output, from_logits=False):
    bce = target * jnp.log(output)
    bce += (1.0 - target) * jnp.log(1.0 - output)
    return -bce


def moments(x, axes, keepdims=False):
    # The dynamic range of float16 is too limited for statistics. As a
    # workaround, we simply perform the operations on float32 and convert back
    # to float16.
    need_cast = False
    ori_dtype = standardize_dtype(x.dtype)
    if ori_dtype == "float16":
        need_cast = True
        x = cast(x, "float32")

    mean = jnp.mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, which is
    # faster but less numerically stable.
    # Note: stop_gradient does not change the gradient with respect to the
    # mean, because that gradient is zero.
    variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
        jax.lax.stop_gradient(mean)
    )

    if not keepdims:
        mean = jnp.squeeze(mean, axes)
        variance = jnp.squeeze(variance, axes)
    if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
        mean = jnp.clip(
            mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
        )
        variance = jnp.clip(
            variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
        )
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance
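Not part of the diff — a minimal usage sketch of the new JAX backend helper above, assuming keras_core is installed and importable with the JAX backend selected (the import path mirrors the file shown here):

import jax.numpy as jnp

from keras_core.backend.jax.nn import moments

x = jnp.array([[0.0, 1.0], [2.0, 3.0]])

# Per-column statistics (reduce over axis 0).
mean, variance = moments(x, axes=[0])
# mean -> [1., 2.], variance -> [1., 1.]

# Global statistics with the reduced axes kept, the way the normalization
# layers below consume the op.
mean, variance = moments(x, axes=[0, 1], keepdims=True)
# mean -> [[1.5]], variance -> [[1.25]]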
32 changes: 32 additions & 0 deletions keras_core/backend/numpy/nn.py
@@ -4,6 +4,7 @@
from jax import numpy as jnp

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_jax,
)
@@ -519,3 +520,34 @@ def binary_crossentropy(target, output, from_logits=False):
    bce = target * np.log(output)
    bce += (1.0 - target) * np.log(1.0 - output)
    return -bce


def moments(x, axes, keepdims=False):
    axes = tuple(axes) if isinstance(axes, list) else axes
    # The dynamic range of float16 is too limited for statistics. As a
    # workaround, we simply perform the operations on float32 and convert back
    # to float16.
    need_cast = False
    ori_dtype = standardize_dtype(x.dtype)
    if ori_dtype == "float16":
        need_cast = True
        x = cast(x, "float32")

    mean = np.mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, which is
    # faster but less numerically stable.
    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)

    if not keepdims:
        mean = np.squeeze(mean, axes)
        variance = np.squeeze(variance, axes)
    if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
        mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max)
        variance = np.clip(
            variance, np.finfo(np.float16).min, np.finfo(np.float16).max
        )
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance
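For context on the float16 handling above, a standalone NumPy sketch (not part of the diff) using the same values as the float16 test added at the bottom of this commit: squaring float16 values with magnitude above roughly 256 overflows to inf, so the single-pass variance needs the float32 round-trip, and even the float32 result can exceed the float16 range, hence the clip before casting back.

import numpy as np

x = np.array(
    [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
    dtype=np.float16,
)

# Naive single-pass variance in float16: np.square overflows to inf.
naive = np.mean(np.square(x)) - np.square(np.mean(x))
print(naive)  # inf

# Compute in float32, then clip into the float16 range before casting back.
var32 = np.var(x.astype(np.float32))
var16 = np.clip(var32, np.finfo(np.float16).min, np.finfo(np.float16).max)
print(var16.astype(np.float16))  # 65504.0, the largest finite float16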
33 changes: 33 additions & 0 deletions keras_core/backend/tensorflow/nn.py
@@ -3,6 +3,7 @@
import tensorflow as tf

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_shape,
)
@@ -646,3 +647,35 @@ def binary_crossentropy(target, output, from_logits=False):
    bce = target * tf.math.log(output)
    bce += (1 - target) * tf.math.log(1 - output)
    return -bce


def moments(x, axes, keepdims=False):
    # The dynamic range of float16 is too limited for statistics. As a
    # workaround, we simply perform the operations on float32 and convert back
    # to float16.
    need_cast = False
    ori_dtype = standardize_dtype(x.dtype)
    if ori_dtype == "float16":
        need_cast = True
        x = cast(x, "float32")

    mean = tf.reduce_mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, which is
    # faster but less numerically stable.
    # Note: stop_gradient does not change the gradient with respect to the
    # mean, because that gradient is zero.
    variance = tf.reduce_mean(
        tf.square(x), axis=axes, keepdims=True
    ) - tf.square(tf.stop_gradient(mean))

    if not keepdims:
        mean = tf.squeeze(mean, axes)
        variance = tf.squeeze(variance, axes)
    if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
        mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)
        variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance
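As a sanity check (illustrative, not part of the diff): core TensorFlow ships tf.nn.moments with the same reduce-over-axes semantics, so the formulation above can be compared against it directly; only the way the variance is assembled differs.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.uniform(size=(2, 28, 28, 3)), dtype=tf.float32)

# Reference from the built-in op.
ref_mean, ref_var = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)

# Same quantities via the E[x^2] - E[x]^2 formulation used in this commit.
mean = tf.reduce_mean(x, axis=[0, 1, 2], keepdims=True)
var = tf.reduce_mean(tf.square(x), axis=[0, 1, 2], keepdims=True) - tf.square(mean)

np.testing.assert_allclose(ref_mean.numpy(), mean.numpy(), atol=1e-5)
np.testing.assert_allclose(ref_var.numpy(), var.numpy(), atol=1e-5)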
42 changes: 42 additions & 0 deletions keras_core/backend/torch/nn.py
@@ -3,6 +3,7 @@
import torch.nn.functional as tnn

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_torch,
)
@@ -609,3 +610,44 @@ def binary_crossentropy(target, output, from_logits=False):
    else:
        output = torch.clip(output, epsilon(), 1.0 - epsilon())
    return tnn.binary_cross_entropy(output, target, reduction="none")


def moments(x, axes, keepdims=False):
    x = convert_to_tensor(x)
    # The dynamic range of float16 is too limited for statistics. As a
    # workaround, we simply perform the operations on float32 and convert back
    # to float16.
    need_cast = False
    ori_dtype = standardize_dtype(x.dtype)
    if ori_dtype == "float16":
        need_cast = True
        x = cast(x, "float32")

    mean = torch.mean(x, dim=axes, keepdim=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, which is
    # faster but less numerically stable.
    # Note: stop_gradient does not change the gradient with respect to the
    # mean, because that gradient is zero.
    variance = torch.mean(
        torch.square(x), dim=axes, keepdim=True
    ) - torch.square(mean.detach())

    if not keepdims:
        mean = torch.squeeze(mean, axes)
        variance = torch.squeeze(variance, axes)
    if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
        mean = torch.clip(
            mean,
            torch.finfo(torch.float16).min,
            torch.finfo(torch.float16).max,
        )
        variance = torch.clip(
            variance,
            torch.finfo(torch.float16).min,
            torch.finfo(torch.float16).max,
        )
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance
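A quick numerical check of the single-pass formula (a plain PyTorch sketch that mirrors the logic above rather than importing the backend, not part of the diff): it yields the population variance, i.e. torch.var with unbiased=False, not the default sample variance.

import torch

x = torch.rand(2, 28, 28, 3)

# E[x^2] - E[x]^2 over the batch/spatial axes, keeping dims for broadcasting.
mean = torch.mean(x, dim=(0, 1, 2), keepdim=True)
var = torch.mean(torch.square(x), dim=(0, 1, 2), keepdim=True) - torch.square(mean)

# Population variance for comparison (unbiased=False disables Bessel's correction).
ref = torch.var(x, dim=(0, 1, 2), keepdim=True, unbiased=False)
print(torch.allclose(var, ref, atol=1e-5))  # True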
7 changes: 3 additions & 4 deletions keras_core/layers/normalization/batch_normalization.py
@@ -198,10 +198,9 @@ def call(self, inputs, training=None, mask=None):
        broadcast_shape = [1] * len(inputs.shape)
        broadcast_shape[self.axis] = inputs.shape[self.axis]
        if training and self.trainable:
            mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True)
            variance = ops.mean(
                ops.square(inputs), axis=self._reduction_axes, keepdims=True
            ) - ops.square(mean)
            mean, variance = ops.moments(
                inputs, axes=self._reduction_axes, keepdims=True
            )
            outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
            mean = ops.squeeze(mean, self._reduction_axes)
            variance = ops.squeeze(variance, self._reduction_axes)
7 changes: 2 additions & 5 deletions keras_core/layers/normalization/group_normalization.py
@@ -171,11 +171,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
        axis = -2 if self.axis == -1 else self.axis - 1
        group_reduction_axes.pop(axis)

        mean = ops.mean(
            reshaped_inputs, axis=group_reduction_axes, keepdims=True
        )
        variance = ops.var(
            reshaped_inputs, axis=group_reduction_axes, keepdims=True
        mean, variance = ops.moments(
            reshaped_inputs, axes=group_reduction_axes, keepdims=True
        )
        gamma, beta = self._get_reshaped_weights(input_shape)

14 changes: 6 additions & 8 deletions keras_core/layers/normalization/layer_normalization.py
@@ -203,19 +203,17 @@ def _broadcast(v):
        # this is at least as numerically stable as the fused version.
        inputs = ops.cast(inputs, "float32")

        # Calculate the variance last axis (layer activations).
        variance = ops.var(inputs, axis=self.axis, keepdims=True)

        # Compute the batch normalization.
        inv = 1 / ops.sqrt(variance + self.epsilon)

        if self.rms_scaling:
            # Calculate outputs with only variance and gamma if rms scaling
            # is enabled
            # Calculate the variance along last axis (layer activations).
            variance = ops.var(inputs, axis=self.axis, keepdims=True)
            inv = 1 / ops.sqrt(variance + self.epsilon)
            outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma
        else:
            # Calculate the mean last axis (layer activations).
            mean = ops.mean(inputs, axis=self.axis, keepdims=True)
            # Calculate the mean & variance along last axis (layer activations).
            mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
            inv = 1 / ops.sqrt(variance + self.epsilon)
        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
        if scale is not None:
            scale = ops.cast(scale, inputs.dtype)
59 changes: 58 additions & 1 deletion keras_core/ops/nn.py
@@ -10,6 +10,7 @@
)
from keras_core.ops import operation_utils
from keras_core.ops.operation import Operation
from keras_core.ops.operation_utils import reduce_shape


class Relu(Operation):
@@ -1551,7 +1552,6 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None):
    Returns:
        Tensor: The multi-hot encoded tensor.
    Example:
    >>> data = keras_core.ops.convert_to_tensor([0, 4])
@@ -1563,3 +1563,60 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None):
        return MultiHot(num_tokens, axis, dtype).symbolic_call(inputs)

    return backend.nn.multi_hot(inputs, num_tokens, axis, dtype)


class Moments(Operation):
    def __init__(self, axes, keepdims=False, name=None):
        super().__init__(name)
        self.axes = axes
        self.keepdims = keepdims

    def call(self, x):
        return backend.nn.moments(x, axes=self.axes, keepdims=self.keepdims)

    def compute_output_spec(self, x):
        return (
            KerasTensor(
                reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),
                dtype=x.dtype,
            ),
            KerasTensor(
                reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),
                dtype=x.dtype,
            ),
        )


@keras_core_export(
    [
        "keras_core.ops.moments",
        "keras_core.ops.nn.moments",
    ]
)
def moments(x, axes, keepdims=False):
    """Calculates the mean and variance of `x`.

    The mean and variance are calculated by aggregating the contents of `x`
    across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and
    variance of a vector.

    Args:
        x: Input tensor.
        axes: A list of axes along which to compute the mean and variance.
        keepdims: If this is set to `True`, the axes which are reduced are left
            in the result as dimensions with size one.

    Returns:
        A tuple containing two tensors - mean and variance.

    Example:

    >>> x = keras_core.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32")
    >>> keras_core.ops.moments(x, axes=[0])
    (array(21.2, dtype=float32), array(1553.3601, dtype=float32))
    """
    if any_symbolic_tensors((x,)):
        return Moments(axes, keepdims).symbolic_call(x)

    return backend.nn.moments(x, axes, keepdims)
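An end-to-end usage sketch of the new public op (not part of the diff; assumes keras_core is installed with any backend), following the normalize-then-scale pattern the updated layers use:

import numpy as np

from keras_core import ops

x = ops.convert_to_tensor(np.random.uniform(size=(8, 16)).astype("float32"))

# Reduce over the batch axis, keeping it so the result broadcasts back over x.
mean, variance = ops.moments(x, axes=[0], keepdims=True)

# Same pattern as BatchNormalization/LayerNormalization above.
normalized = (x - mean) / ops.sqrt(variance + 1e-5)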
60 changes: 60 additions & 0 deletions keras_core/ops/nn_test.py
@@ -298,6 +298,20 @@ def test_one_hot_dtype(self, dtype):
        out = knn.one_hot(x, 5, axis=0, dtype=dtype)
        self.assertEqual(backend.standardize_dtype(out.dtype), dtype)

    def test_moments(self):
        x = KerasTensor([None, 3, 4])
        self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))
        self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))
        self.assertEqual(
            knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)
        )

        self.assertEqual(knn.moments(x, axes=[1])[0].shape, (None, 4))
        self.assertEqual(knn.moments(x, axes=[1, 2])[0].shape, (None,))
        self.assertEqual(
            knn.moments(x, axes=[1, 2], keepdims=True)[0].shape, (None, 1, 1)
        )


class NNOpsStaticShapeTest(testing.TestCase):
    def test_relu(self):
@@ -591,6 +605,14 @@ def test_sparse_categorical_crossentropy(self):
            knn.sparse_categorical_crossentropy(x1, x2).shape, (2, 3)
        )

    def test_moments(self):
        x = KerasTensor([2, 3, 4])
        self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))
        self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))
        self.assertEqual(
            knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)
        )


class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
    def test_relu(self):
@@ -1156,3 +1178,41 @@ def test_multi_hot(self):
        indices_1d = np.array([0, -1, -1, 3])
        expected_output_1d = np.array([1, 0, 0, 1])
        self.assertAllClose(knn.multi_hot(indices_1d, 4), expected_output_1d)

    def test_moments(self):
        # Test 1D moments
        x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32)
        mean, variance = knn.moments(x, axes=[0])
        self.assertAllClose(mean, np.mean(x))
        self.assertAllClose(variance, np.var(x))

        # Test batch statistics for 4D moments (batch, height, width, channels)
        x = np.random.uniform(size=(2, 28, 28, 3))
        mean, variance = knn.moments(x, axes=[0])
        self.assertAllClose(mean, np.mean(x, axis=0))
        self.assertAllClose(variance, np.var(x, axis=0))

        # Test global statistics for 4D moments (batch, height, width, channels)
        x = np.random.uniform(size=(2, 28, 28, 3))
        mean, variance = knn.moments(x, axes=[0, 1, 2])
        self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2)))
        self.assertAllClose(variance, np.var(x, axis=(0, 1, 2)))

        # Test keepdims
        x = np.random.uniform(size=(2, 28, 28, 3))
        mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True)
        self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2), keepdims=True))
        self.assertAllClose(variance, np.var(x, axis=(0, 1, 2), keepdims=True))

        # Test float16 which causes overflow
        x = np.array(
            [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
            dtype=np.float16,
        )
        mean, variance = knn.moments(x, axes=[0])
        expected_mean = np.mean(x.astype(np.float32)).astype(np.float16)
        # the output variance is clipped to the max value of np.float16
        # because it overflows
        expected_variance = np.finfo(np.float16).max
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_variance)
