Skip to content

Commit

Permalink
Fix Layer normalization issue with scalar mean & variance (#20626)
Browse files Browse the repository at this point in the history
* Fix normalization issue with scalar mean & variance

* Add unit test for normalization with scalar mean and variance
  • Loading branch information
Surya2k1 authored Dec 11, 2024
1 parent 57e29a6 commit aab9458
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras/src/layers/preprocessing/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def build(self, input_shape):
# with proper broadcast shape for use during call.
mean = ops.convert_to_tensor(self.input_mean)
variance = ops.convert_to_tensor(self.input_variance)
mean = ops.reshape(mean, self._broadcast_shape)
variance = ops.reshape(variance, self._broadcast_shape)
mean = ops.broadcast_to(mean, self._broadcast_shape)
variance = ops.broadcast_to(variance, self._broadcast_shape)
self.mean = ops.cast(mean, dtype=self.compute_dtype)
self.variance = ops.cast(variance, dtype=self.compute_dtype)
self.built = True
Expand Down
5 changes: 5 additions & 0 deletions keras/src/layers/preprocessing/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,8 @@ def test_tf_data_compatibility(self):
)
for output in ds.map(layer).take(1):
output.numpy()

def test_normalization_with_scalar_mean_var(self):
input_data = np.array([[1,2,3]], dtype='float32')
layer = layers.Normalization(mean=3., variance=2.)
layer(input_data)

0 comments on commit aab9458

Please sign in to comment.