Add ops.nn.moments and speed-up normalization layers #866

Merged (5 commits) on Sep 14, 2023
37 changes: 37 additions & 0 deletions keras_core/backend/jax/nn.py
@@ -5,6 +5,7 @@
from jax import nn as jnn

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_padding_args_for_jax,
)
@@ -486,3 +487,39 @@ def binary_crossentropy(target, output, from_logits=False):
bce = target * jnp.log(output)
bce += (1.0 - target) * jnp.log(1.0 - output)
return -bce


def moments(x, axes, keepdims=False):
# The dynamic range of float16 is too limited for statistics. As a
# workaround, we simply perform the operations on float32 and convert back
# to float16
need_cast = False
ori_dtype = standardize_dtype(x.dtype)
if ori_dtype == "float16":
need_cast = True
x = cast(x, "float32")

mean = jnp.mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
    # but less numerically stable.
# Note: stop_gradient does not change the gradient to the mean, because that
# gradient is zero.
variance = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - jnp.square(
jax.lax.stop_gradient(mean)
)

if not keepdims:
mean = jnp.squeeze(mean, axes)
variance = jnp.squeeze(variance, axes)
if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
mean = jnp.clip(
mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
)
variance = jnp.clip(
variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
)
mean = cast(mean, ori_dtype)
variance = cast(variance, ori_dtype)
return mean, variance
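
For reference, a minimal standalone check (not part of this diff) that the one-pass identity Var(x) = E[x^2] - (E[x])^2 used above matches jnp.var, which computes the two-pass form E[(x - E[x])^2]:

import jax.numpy as jnp

# Toy data; the values are arbitrary.
x = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 100.0]])

mean = jnp.mean(x, axis=0, keepdims=True)
one_pass_var = jnp.mean(jnp.square(x), axis=0, keepdims=True) - jnp.square(mean)

# jnp.var uses the two-pass form; both agree on well-scaled data.
assert jnp.allclose(one_pass_var, jnp.var(x, axis=0, keepdims=True))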
32 changes: 32 additions & 0 deletions keras_core/backend/numpy/nn.py
@@ -4,6 +4,7 @@
from jax import numpy as jnp

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_padding_args_for_jax,
)
@@ -519,3 +520,34 @@ def binary_crossentropy(target, output, from_logits=False):
bce = target * np.log(output)
bce += (1.0 - target) * np.log(1.0 - output)
return -bce


def moments(x, axes, keepdims=False):
axes = tuple(axes) if isinstance(axes, list) else axes
# The dynamic range of float16 is too limited for statistics. As a
# workaround, we simply perform the operations on float32 and convert back
# to float16
need_cast = False
ori_dtype = standardize_dtype(x.dtype)
if ori_dtype == "float16":
need_cast = True
x = cast(x, "float32")

mean = np.mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
    # but less numerically stable.
variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)

if not keepdims:
mean = np.squeeze(mean, axes)
variance = np.squeeze(variance, axes)
if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max)
variance = np.clip(
variance, np.finfo(np.float16).min, np.finfo(np.float16).max
)
mean = cast(mean, ori_dtype)
variance = cast(variance, ori_dtype)
return mean, variance
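
To make the "less numerically stable" caveat concrete, here is a small standalone sketch (not part of this diff): with a large common offset, E[x^2] and (E[x])^2 are huge and nearly equal, so their float32 difference is dominated by rounding error, while the two-pass np.var stays exact:

import numpy as np

x = (np.arange(4.0) + 1e6).astype("float32")  # [1e6, 1e6 + 1, 1e6 + 2, 1e6 + 3]

one_pass = np.mean(np.square(x)) - np.square(np.mean(x))
two_pass = np.var(x)  # E[(x - E[x])^2], the stable form

print(one_pass)  # swamped by float32 rounding error
print(two_pass)  # 1.25, the exact variance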
33 changes: 33 additions & 0 deletions keras_core/backend/tensorflow/nn.py
@@ -3,6 +3,7 @@
import tensorflow as tf

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_output_shape,
)
@@ -646,3 +647,35 @@ def binary_crossentropy(target, output, from_logits=False):
bce = target * tf.math.log(output)
bce += (1 - target) * tf.math.log(1 - output)
return -bce


def moments(x, axes, keepdims=False):
# The dynamic range of float16 is too limited for statistics. As a
# workaround, we simply perform the operations on float32 and convert back
# to float16
need_cast = False
ori_dtype = standardize_dtype(x.dtype)
if ori_dtype == "float16":
need_cast = True
x = cast(x, "float32")

mean = tf.reduce_mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
    # but less numerically stable.
# Note: stop_gradient does not change the gradient to the mean, because that
# gradient is zero.
variance = tf.reduce_mean(
tf.square(x), axis=axes, keepdims=True
) - tf.square(tf.stop_gradient(mean))

if not keepdims:
mean = tf.squeeze(mean, axes)
variance = tf.squeeze(variance, axes)
if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)
variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)
mean = cast(mean, ori_dtype)
variance = cast(variance, ori_dtype)
return mean, variance
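
For context, a standalone sanity check (not part of this diff) that the one-pass computation above agrees with tf.nn.moments, which uses the more stable squared-difference form, on well-scaled float32 data:

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.uniform(size=(2, 28, 28, 3)), dtype=tf.float32)
axes = [0, 1, 2]

mean = tf.reduce_mean(x, axes, keepdims=True)
variance = tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) - tf.square(
    tf.stop_gradient(mean)
)

# tf.nn.moments computes E[(x - E[x])^2] instead of E[x^2] - (E[x])^2.
ref_mean, ref_variance = tf.nn.moments(x, axes=axes, keepdims=True)
np.testing.assert_allclose(mean.numpy(), ref_mean.numpy(), atol=1e-6)
np.testing.assert_allclose(variance.numpy(), ref_variance.numpy(), atol=1e-5)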
42 changes: 42 additions & 0 deletions keras_core/backend/torch/nn.py
@@ -3,6 +3,7 @@
import torch.nn.functional as tnn

from keras_core.backend import standardize_data_format
from keras_core.backend import standardize_dtype
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_padding_args_for_torch,
)
@@ -609,3 +610,44 @@ def binary_crossentropy(target, output, from_logits=False):
else:
output = torch.clip(output, epsilon(), 1.0 - epsilon())
return tnn.binary_cross_entropy(output, target, reduction="none")


def moments(x, axes, keepdims=False):
x = convert_to_tensor(x)
# The dynamic range of float16 is too limited for statistics. As a
# workaround, we simply perform the operations on float32 and convert back
# to float16
need_cast = False
ori_dtype = standardize_dtype(x.dtype)
if ori_dtype == "float16":
need_cast = True
x = cast(x, "float32")

mean = torch.mean(x, dim=axes, keepdim=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster
    # but less numerically stable.
# Note: stop_gradient does not change the gradient to the mean, because that
# gradient is zero.
variance = torch.mean(
torch.square(x), dim=axes, keepdim=True
) - torch.square(mean.detach())

if not keepdims:
mean = torch.squeeze(mean, axes)
variance = torch.squeeze(variance, axes)
if need_cast:
        # avoid overflow and underflow when casting from float32 to float16
mean = torch.clip(
mean,
torch.finfo(torch.float16).min,
torch.finfo(torch.float16).max,
)
variance = torch.clip(
variance,
torch.finfo(torch.float16).min,
torch.finfo(torch.float16).max,
)
mean = cast(mean, ori_dtype)
variance = cast(variance, ori_dtype)
return mean, variance
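
To illustrate the float16 workaround shared by all four backends, here is a small standalone torch sketch (not part of this diff) reusing the values from the new float16 test at the bottom of this PR: squaring them overflows directly in float16, while the upcast-clip-downcast path returns a finite, saturated variance:

import torch

x = torch.tensor(
    [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
    dtype=torch.float16,
)

# Direct float16 computation: most squares exceed the float16 max (~65504).
naive_var = torch.mean(torch.square(x)) - torch.square(torch.mean(x))
print(naive_var)  # inf

# Upcast to float32, compute, clip to the float16 range, cast back.
x32 = x.float()
var32 = torch.mean(torch.square(x32)) - torch.square(torch.mean(x32))
var16 = torch.clip(
    var32, torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
).half()
print(var16)  # 65504.0 (the float16 max) instead of inf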
7 changes: 3 additions & 4 deletions keras_core/layers/normalization/batch_normalization.py
@@ -198,10 +198,9 @@ def call(self, inputs, training=None, mask=None):
         broadcast_shape = [1] * len(inputs.shape)
         broadcast_shape[self.axis] = inputs.shape[self.axis]
         if training and self.trainable:
-            mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True)
-            variance = ops.mean(
-                ops.square(inputs), axis=self._reduction_axes, keepdims=True
-            ) - ops.square(mean)
+            mean, variance = ops.moments(
+                inputs, axes=self._reduction_axes, keepdims=True
+            )
             outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
             mean = ops.squeeze(mean, self._reduction_axes)
             variance = ops.squeeze(variance, self._reduction_axes)
7 changes: 2 additions & 5 deletions keras_core/layers/normalization/group_normalization.py
@@ -171,11 +171,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
         axis = -2 if self.axis == -1 else self.axis - 1
         group_reduction_axes.pop(axis)

-        mean = ops.mean(
-            reshaped_inputs, axis=group_reduction_axes, keepdims=True
-        )
-        variance = ops.var(
-            reshaped_inputs, axis=group_reduction_axes, keepdims=True
+        mean, variance = ops.moments(
+            reshaped_inputs, axes=group_reduction_axes, keepdims=True
         )
         gamma, beta = self._get_reshaped_weights(input_shape)

14 changes: 6 additions & 8 deletions keras_core/layers/normalization/layer_normalization.py
@@ -203,19 +203,17 @@ def _broadcast(v):
             # this is at least as numerically stable as the fused version.
             inputs = ops.cast(inputs, "float32")

-        # Calculate the variance last axis (layer activations).
-        variance = ops.var(inputs, axis=self.axis, keepdims=True)
-
-        # Compute the batch normalization.
-        inv = 1 / ops.sqrt(variance + self.epsilon)
-
         if self.rms_scaling:
             # Calculate outputs with only variance and gamma if rms scaling
             # is enabled
+            # Calculate the variance along last axis (layer activations).
+            variance = ops.var(inputs, axis=self.axis, keepdims=True)
+            inv = 1 / ops.sqrt(variance + self.epsilon)
             outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma
         else:
-            # Calculate the mean last axis (layer activations).
-            mean = ops.mean(inputs, axis=self.axis, keepdims=True)
+            # Calculate the mean & variance along last axis (layer activations).
+            mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
+            inv = 1 / ops.sqrt(variance + self.epsilon)
             scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
             if scale is not None:
                 scale = ops.cast(scale, inputs.dtype)
59 changes: 58 additions & 1 deletion keras_core/ops/nn.py
@@ -10,6 +10,7 @@
)
from keras_core.ops import operation_utils
from keras_core.ops.operation import Operation
from keras_core.ops.operation_utils import reduce_shape


class Relu(Operation):
@@ -1551,7 +1552,6 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None):
Returns:
Tensor: The multi-hot encoded tensor.


Example:

>>> data = keras_core.ops.convert_to_tensor([0, 4])
@@ -1563,3 +1563,60 @@ def multi_hot(inputs, num_tokens, axis=-1, dtype=None):
return MultiHot(num_tokens, axis, dtype).symbolic_call(inputs)

return backend.nn.multi_hot(inputs, num_tokens, axis, dtype)


class Moments(Operation):
def __init__(self, axes, keepdims=False, name=None):
super().__init__(name)
self.axes = axes
self.keepdims = keepdims

def call(self, x):
return backend.nn.moments(x, axes=self.axes, keepdims=self.keepdims)

def compute_output_spec(self, x):
return (
KerasTensor(
reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),
dtype=x.dtype,
),
KerasTensor(
reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),
dtype=x.dtype,
),
)


@keras_core_export(
[
"keras_core.ops.moments",
"keras_core.ops.nn.moments",
]
)
def moments(x, axes, keepdims=False):
"""Calculates the mean and variance of `x`.

The mean and variance are calculated by aggregating the contents of `x`
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and
variance of a vector.

Args:
x: Input tensor.
        axes: A list of axes along which to compute the mean and variance.
keepdims: If this is set to `True`, the axes which are reduced are left
in the result as dimensions with size one.

Returns:
A tuple containing two tensors - mean and variance.

Example:

>>> x = keras_core.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32")
>>> keras_core.ops.moments(x, axes=[0])
(array(21.2, dtype=float32), array(1553.3601, dtype=float32))

"""
if any_symbolic_tensors((x,)):
return Moments(axes, keepdims).symbolic_call(x)

return backend.nn.moments(x, axes, keepdims)
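
As a usage sketch (assuming keras_core is installed with any backend), the new op covers the per-channel statistics pattern that the updated normalization layers rely on:

import numpy as np
from keras_core import ops

# Per-channel moments over the batch and spatial axes of an NHWC tensor,
# followed by a plain normalization step (the epsilon value is arbitrary).
x = np.random.uniform(size=(2, 28, 28, 3)).astype("float32")
x = ops.convert_to_tensor(x)

mean, variance = ops.moments(x, axes=[0, 1, 2], keepdims=True)
normalized = (x - mean) / ops.sqrt(variance + 1e-3)
print(ops.shape(mean))  # (1, 1, 1, 3)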
60 changes: 60 additions & 0 deletions keras_core/ops/nn_test.py
@@ -298,6 +298,20 @@ def test_one_hot_dtype(self, dtype):
out = knn.one_hot(x, 5, axis=0, dtype=dtype)
self.assertEqual(backend.standardize_dtype(out.dtype), dtype)

def test_moments(self):
x = KerasTensor([None, 3, 4])
self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))
self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))
self.assertEqual(
knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)
)

self.assertEqual(knn.moments(x, axes=[1])[0].shape, (None, 4))
self.assertEqual(knn.moments(x, axes=[1, 2])[0].shape, (None,))
self.assertEqual(
knn.moments(x, axes=[1, 2], keepdims=True)[0].shape, (None, 1, 1)
)


class NNOpsStaticShapeTest(testing.TestCase):
def test_relu(self):
@@ -591,6 +605,14 @@ def test_sparse_categorical_crossentropy(self):
knn.sparse_categorical_crossentropy(x1, x2).shape, (2, 3)
)

def test_moments(self):
x = KerasTensor([2, 3, 4])
self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))
self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))
self.assertEqual(
knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)
)


class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
def test_relu(self):
@@ -1156,3 +1178,41 @@ def test_multi_hot(self):
indices_1d = np.array([0, -1, -1, 3])
expected_output_1d = np.array([1, 0, 0, 1])
self.assertAllClose(knn.multi_hot(indices_1d, 4), expected_output_1d)

def test_moments(self):
# Test 1D moments
x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32)
mean, variance = knn.moments(x, axes=[0])
self.assertAllClose(mean, np.mean(x))
self.assertAllClose(variance, np.var(x))

# Test batch statistics for 4D moments (batch, height, width, channels)
x = np.random.uniform(size=(2, 28, 28, 3))
mean, variance = knn.moments(x, axes=[0])
self.assertAllClose(mean, np.mean(x, axis=0))
self.assertAllClose(variance, np.var(x, axis=0))

# Test global statistics for 4D moments (batch, height, width, channels)
x = np.random.uniform(size=(2, 28, 28, 3))
mean, variance = knn.moments(x, axes=[0, 1, 2])
self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2)))
self.assertAllClose(variance, np.var(x, axis=(0, 1, 2)))

# Test keepdims
x = np.random.uniform(size=(2, 28, 28, 3))
mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True)
self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2), keepdims=True))
self.assertAllClose(variance, np.var(x, axis=(0, 1, 2), keepdims=True))

# Test float16 which causes overflow
x = np.array(
[-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],
dtype=np.float16,
)
mean, variance = knn.moments(x, axes=[0])
expected_mean = np.mean(x.astype(np.float32)).astype(np.float16)
        # the output variance is clipped to the max value of np.float16 because
        # it overflows in float16
expected_variance = np.finfo(np.float16).max
self.assertAllClose(mean, expected_mean)
self.assertAllClose(variance, expected_variance)