Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ops.rsqrt, improve normalization layers and enable ops fusion in tflite #892

Merged
merged 6 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions keras_core/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from keras_core.backend import standardize_dtype
from keras_core.backend.jax.math import fft as jax_fft
from keras_core.backend.jax.math import fft2 as jax_fft2
from keras_core.backend.jax.math import rsqrt as jax_rsqrt
from keras_core.backend.numpy.core import convert_to_tensor
from keras_core.utils.module_utils import scipy

Expand Down Expand Up @@ -298,3 +299,7 @@ def istft(
else:
end = expected_output_len
return x[..., start:end]


def rsqrt(x):
return np.array(jax_rsqrt(x))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not 1. / sqrt(x)? It's numpy native, and we're not worried about performance for the numpy backend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking of being consistent with other backends, but it should be okay to use 1. / sqrt(x)
Fixed.

31 changes: 19 additions & 12 deletions keras_core/layers/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,21 @@ def call(self, inputs, training=None, mask=None):
mean, variance = ops.moments(
inputs, axes=self._reduction_axes, keepdims=True
)
outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
mean = ops.squeeze(mean, self._reduction_axes)
variance = ops.squeeze(variance, self._reduction_axes)
moving_mean = ops.cast(self.moving_mean, inputs.dtype)
moving_variance = ops.cast(self.moving_variance, inputs.dtype)
self.moving_mean.assign(
ops.cast(
moving_mean * self.momentum + mean * (1.0 - self.momentum),
moving_mean * self.momentum
+ ops.squeeze(mean, self._reduction_axes)
* (1.0 - self.momentum),
inputs.dtype,
)
)
self.moving_variance.assign(
ops.cast(
moving_variance * self.momentum
+ variance * (1.0 - self.momentum),
+ ops.squeeze(variance, self._reduction_axes)
* (1.0 - self.momentum),
inputs.dtype,
)
)
Expand All @@ -224,17 +224,24 @@ def call(self, inputs, training=None, mask=None):
moving_variance = ops.cast(self.moving_variance, inputs.dtype)
moving_mean = ops.reshape(moving_mean, broadcast_shape)
moving_variance = ops.reshape(moving_variance, broadcast_shape)
outputs = (inputs - moving_mean) / ops.sqrt(
moving_variance + self.epsilon
)
mean = moving_mean
variance = moving_variance

inv = ops.rsqrt(variance + self.epsilon)
if self.scale:
gamma = ops.reshape(self.gamma, broadcast_shape)
gamma = ops.cast(gamma, outputs.dtype)
outputs = outputs * gamma
gamma = ops.cast(gamma, inputs.dtype)
inv = inv * gamma

res = -mean * inv
if self.center:
beta = ops.reshape(self.beta, broadcast_shape)
beta = ops.cast(beta, outputs.dtype)
outputs = outputs + beta
beta = ops.cast(beta, inputs.dtype)
res = res + beta

# Note: Folding BatchNormalization depends on the precise order of ops
# that are generated by the expression below
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good comment!

outputs = inputs * inv + res
return ops.cast(outputs, input_dtype)

def get_config(self):
Expand Down
33 changes: 11 additions & 22 deletions keras_core/layers/normalization/group_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,37 +171,26 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
axis = -2 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)

broadcast_shape = self._create_broadcast_shape(input_shape)
mean, variance = ops.moments(
reshaped_inputs, axes=group_reduction_axes, keepdims=True
)
gamma, beta = self._get_reshaped_weights(input_shape)

# Compute the batch normalization.
inv = 1 / ops.sqrt(variance + self.epsilon)

if gamma is not None:
inv = ops.multiply(inv, gamma)

if beta is not None:
x = beta - ops.multiply(mean, inv)
else:
x = -ops.multiply(mean, inv)

normalized_inputs = reshaped_inputs * ops.cast(
inv, reshaped_inputs.dtype
) + ops.cast(x, reshaped_inputs.dtype)
normalized_inputs = ops.cast(normalized_inputs, reshaped_inputs.dtype)
return normalized_inputs

def _get_reshaped_weights(self, input_shape):
broadcast_shape = self._create_broadcast_shape(input_shape)
gamma = None
beta = None
inv = ops.rsqrt(variance + self.epsilon)
if self.scale:
gamma = ops.reshape(self.gamma, broadcast_shape)
gamma = ops.cast(gamma, reshaped_inputs.dtype)
inv = inv * gamma

res = -mean * inv
if self.center:
beta = ops.reshape(self.beta, broadcast_shape)
return gamma, beta
beta = ops.cast(beta, reshaped_inputs.dtype)
res = res + beta

normalized_inputs = reshaped_inputs * inv + res
return normalized_inputs

def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * len(input_shape)
Expand Down
45 changes: 45 additions & 0 deletions keras_core/layers/normalization/group_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,51 @@ def test_groupnorm(self):
supports_masking=True,
)

def test_undefined_dim_error(self):
inputs = layers.Input(shape=(2, 2, 2, None))
layer = layers.GroupNormalization()
with self.assertRaisesRegex(
ValueError,
(
"input tensor should have a defined dimension but the layer "
"received an input with shape"
),
):
_ = layer(inputs)

def test_groups_bigger_than_dim_error(self):
inputs = np.ones(shape=(2, 2, 2, 4))
layer = layers.GroupNormalization(groups=5)
with self.assertRaisesRegex(
ValueError,
"cannot be more than the number of channels",
):
_ = layer(inputs)

def test_groups_not_a_multiple_of_dim_error(self):
inputs = np.ones(shape=(2, 2, 2, 4))
layer = layers.GroupNormalization(groups=3)
with self.assertRaisesRegex(
ValueError,
"must be a multiple of the number of channels",
):
_ = layer(inputs)

def test_groups_instance_norm(self):
# GroupNormalization with groups=-1 will become InstanceNormalization
instance_norm_layer_1 = layers.GroupNormalization(
groups=-1, axis=-1, scale=False, center=False
)
instance_norm_layer_2 = layers.GroupNormalization(
groups=4, axis=-1, scale=False, center=False
)
inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])

outputs_1 = instance_norm_layer_1(inputs)
outputs_2 = instance_norm_layer_2(inputs)

self.assertAllClose(outputs_1, outputs_2)

def test_correctness_instance_norm(self):
instance_norm_layer = layers.GroupNormalization(
groups=4, axis=-1, scale=False, center=False
Expand Down
38 changes: 17 additions & 21 deletions keras_core/layers/normalization/layer_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,33 +206,29 @@ def _broadcast(v):
if self.rms_scaling:
# Calculate outputs with only variance and gamma if rms scaling
# is enabled
# Calculate the variance along last axis (layer activations).
# Calculate the variance along self.axis (layer activations).
variance = ops.var(inputs, axis=self.axis, keepdims=True)
inv = 1 / ops.sqrt(variance + self.epsilon)
outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma
inv = ops.rsqrt(variance + self.epsilon)

outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)
else:
# Calculate the mean & variance along last axis (layer activations).
# Calculate the mean & variance along self.axis (layer activations).
mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
inv = 1 / ops.sqrt(variance + self.epsilon)
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
if scale is not None:
scale = ops.cast(scale, inputs.dtype)
inv = inv * scale
x = -mean * inv
if offset is not None:
offset = ops.cast(offset, inputs.dtype)
x = offset + x

outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
x, inputs.dtype
)
gamma, beta = _broadcast(self.gamma), _broadcast(self.beta)

inv = ops.rsqrt(variance + self.epsilon)
if gamma is not None:
gamma = ops.cast(gamma, inputs.dtype)
inv = inv * gamma

outputs = ops.cast(outputs, input_dtype)
res = -mean * inv
if beta is not None:
beta = ops.cast(beta, inputs.dtype)
res = res + beta

# If some components of the shape got lost due to adjustments, fix that.
outputs = ops.reshape(outputs, ops.shape(inputs))
outputs = inputs * inv + res

return outputs
return ops.cast(outputs, input_dtype)

def compute_output_shape(self, input_shape):
return input_shape
Expand Down
10 changes: 10 additions & 0 deletions keras_core/layers/normalization/layer_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ def test_ln_basics(self):
supports_masking=True,
)

def test_invalid_axis(self):
with self.assertRaisesRegex(
TypeError,
(
"Expected an int or a list/tuple of ints for the argument "
"'axis'"
),
):
layers.LayerNormalization(axis={"axis": -1})

def test_correctness(self):
layer = layers.LayerNormalization(dtype="float32")
layer.build(input_shape=(2, 2, 2))
Expand Down
26 changes: 26 additions & 0 deletions keras_core/layers/normalization/spectral_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,32 @@ def test_basic_spectralnorm(self):
expected_num_losses=0,
supports_masking=False,
)
self.run_layer_test(
layers.SpectralNormalization,
init_kwargs={"layer": layers.Embedding(10, 4)},
input_data=np.random.randint(10, size=(10,)),
expected_output_shape=(10, 4),
expected_num_trainable_weights=1,
expected_num_non_trainable_weights=1,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
run_training_check=False,
)

def test_invalid_power_iterations(self):
with self.assertRaisesRegex(
ValueError, "`power_iterations` should be greater than zero."
):
layers.SpectralNormalization(layers.Dense(2), power_iterations=0)

def test_invalid_layer(self):
layer = layers.SpectralNormalization(layers.ReLU())
inputs = np.ones(shape=(4, 2))
with self.assertRaisesRegex(
ValueError, "object has no attribute 'kernel' nor 'embeddings'"
):
layer(inputs)

def test_apply_layer(self):
images = np.ones((1, 2, 2, 1))
Expand Down
2 changes: 1 addition & 1 deletion keras_core/layers/normalization/unit_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def call(self, inputs):
x = ops.cast(inputs, self.compute_dtype)

square_sum = ops.sum(ops.square(x), axis=self.axis, keepdims=True)
x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12))
x_inv_norm = ops.rsqrt(ops.maximum(square_sum, 1e-12))
return ops.multiply(x, x_inv_norm)

def compute_output_shape(self, input_shape):
Expand Down
10 changes: 10 additions & 0 deletions keras_core/layers/normalization/unit_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def test_un_basics(self):
supports_masking=True,
)

def test_invalid_axis(self):
with self.assertRaisesRegex(
TypeError,
(
"Invalid value for `axis` argument: expected an int or a "
"list/tuple of ints."
),
):
layers.UnitNormalization(axis={"axis": -1})

def test_correctness(self):
layer = layers.UnitNormalization(axis=-1)
inputs = np.random.normal(size=(2, 3))
Expand Down
4 changes: 0 additions & 4 deletions keras_core/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,10 +831,6 @@ def test_istft(
ref = ref[..., truncated_len:-truncated_len]
self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)

@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy does not support rsqrt.",
)
def test_rsqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x))
Expand Down