From 40322a7e249697b97282fe03a03168ebe6df9e5b Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Mon, 23 Dec 2024 20:59:17 +0100 Subject: [PATCH 1/4] Fix: Ensure Attention Layer Returns Attention Scores when `return_attention_scores=True` This pull request addresses an issue in the Attention layer where the return_attention_scores parameter wasn't correctly handled in the compute_output_shape method. This fix ensures that attention scores are returned when return_attention_scores=True. ## Changes Made Modified compute_output_shape method to return the shape of both the attention output and the attention scores when return_attention_scores=True. --- keras/src/layers/attention/attention.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index 592468fe802e..dcf8a4acb97a 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -244,8 +244,11 @@ def compute_mask(self, inputs, mask=None): return ops.convert_to_tensor(mask[0]) def compute_output_shape(self, input_shape): - """Returns shape of value tensor dim, but for query tensor length""" - return (*input_shape[0][:-1], input_shape[1][-1]) + output_shape = (*input_shape[0][:-1], input_shape[1][-1]) + if self.return_attention_scores: + scores_shape = (input_shape[0][0], input_shape[0][1], input_shape[1][1]) + return output_shape, scores_shape + return output_shape def _validate_inputs(self, inputs, mask=None): """Validates arguments of the call method.""" From cb8fe894c4a8afa7ac378161ad9bbe2ae2a52c6e Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Mon, 23 Dec 2024 20:23:22 +0000 Subject: [PATCH 2/4] Formatting --- keras/src/layers/attention/attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index dcf8a4acb97a..13ab65acb521 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -246,7 +246,11 @@ def compute_mask(self, inputs, mask=None): def compute_output_shape(self, input_shape): output_shape = (*input_shape[0][:-1], input_shape[1][-1]) if self.return_attention_scores: - scores_shape = (input_shape[0][0], input_shape[0][1], input_shape[1][1]) + scores_shape = ( + input_shape[0][0], + input_shape[0][1], + input_shape[1][1], + ) return output_shape, scores_shape return output_shape From 492ebf0780404897adb5588984693e628bb751e0 Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Mon, 23 Dec 2024 22:10:25 +0000 Subject: [PATCH 3/4] Fixed score return and added unit tests for return_attention_scores=True --- keras/src/layers/attention/attention.py | 62 ++++++++++++++++---- keras/src/layers/attention/attention_test.py | 59 +++++++++++++++++++ 2 files changed, 110 insertions(+), 11 deletions(-) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index 13ab65acb521..923dcce9a1b6 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -1,6 +1,7 @@ from keras.src import backend from keras.src import ops from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor from keras.src.layers.layer import Layer @@ -84,6 +85,8 @@ def __init__( f"Received: score_mode={score_mode}" ) + 
self._return_attention_scores = False + def build(self, input_shape): self._validate_inputs(input_shape) self.scale = None @@ -137,6 +140,7 @@ def _calculate_scores(self, query, key): else: raise ValueError("scores not computed") + print("scores", scores) return scores def _apply_scores(self, scores, value, scores_mask=None, training=False): @@ -217,6 +221,7 @@ def call( use_causal_mask=False, ): self._validate_inputs(inputs=inputs, mask=mask) + self._return_attention_scores = return_attention_scores q = inputs[0] v = inputs[1] k = inputs[2] if len(inputs) > 2 else v @@ -226,16 +231,17 @@ def call( scores_mask = self._calculate_score_mask( scores, v_mask, use_causal_mask ) - result, attention_scores = self._apply_scores( + attention_output, attention_scores = self._apply_scores( scores=scores, value=v, scores_mask=scores_mask, training=training ) if q_mask is not None: # Mask of shape [batch_size, Tq, 1]. q_mask = ops.expand_dims(q_mask, axis=-1) - result *= ops.cast(q_mask, dtype=result.dtype) + attention_output *= ops.cast(q_mask, dtype=attention_output.dtype) if return_attention_scores: - return result, attention_scores - return result + return (attention_output, attention_scores) + else: + return attention_output def compute_mask(self, inputs, mask=None): self._validate_inputs(inputs=inputs, mask=mask) @@ -244,16 +250,50 @@ def compute_mask(self, inputs, mask=None): return ops.convert_to_tensor(mask[0]) def compute_output_shape(self, input_shape): - output_shape = (*input_shape[0][:-1], input_shape[1][-1]) - if self.return_attention_scores: - scores_shape = ( - input_shape[0][0], - input_shape[0][1], - input_shape[1][1], - ) + query_shape, value_shape, key_shape = input_shape + if key_shape is None: + key_shape = value_shape + + output_shape = (*query_shape[:-1], value_shape[-1]) + if self._return_attention_scores: + scores_shape = (query_shape[0], query_shape[1], key_shape[1]) return output_shape, scores_shape return output_shape + def compute_output_spec( + self, + inputs, + mask=None, + return_attention_scores=False, + training=None, + use_causal_mask=False, + ): + # Validate and unpack inputs + self._validate_inputs(inputs, mask) + query = inputs[0] + value = inputs[1] + key = inputs[2] if len(inputs) > 2 else value + + # Compute primary output shape + output_shape = self.compute_output_shape( + [query.shape, value.shape, key.shape] + ) + output_spec = KerasTensor(output_shape, dtype=self.compute_dtype) + + # Handle attention scores if requested + if self._return_attention_scores: + scores_shape = ( + query.shape[0], + query.shape[1], + key.shape[1], + ) # (batch_size, Tq, Tv) + attention_scores_spec = KerasTensor( + scores_shape, dtype=self.compute_dtype + ) + return (output_spec, attention_scores_spec) + + return output_spec + def _validate_inputs(self, inputs, mask=None): """Validates arguments of the call method.""" class_name = self.__class__.__name__ diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py index de8dba643405..eab40b2a0386 100644 --- a/keras/src/layers/attention/attention_test.py +++ b/keras/src/layers/attention/attention_test.py @@ -358,3 +358,62 @@ def test_attention_compute_output_shape(self): ), output.shape, ) + + def test_return_attention_scores_true(self): + """Test that the layer returns attention scores along with outputs.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention 
layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + output, attention_scores = layer( + [query, value], return_attention_scores=True + ) + + # Check the shape of the outputs + self.assertEqual(output.shape, (2, 8, 16)) # Output shape + self.assertEqual( + attention_scores.shape, (2, 8, 4) + ) # Attention scores shape + + def test_return_attention_scores_true_and_tuple(self): + """Test that the layer outputs are a tuple when + return_attention_scores=True.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + outputs = layer([query, value], return_attention_scores=True) + + # Check that outputs is a tuple + self.assertIsInstance( + outputs, tuple, "Expected the outputs to be a tuple" + ) + + def test_return_attention_scores_true_tuple_then_unpack(self): + """Test that outputs can be unpacked correctly.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + outputs = layer([query, value], return_attention_scores=True) + + # Unpack the outputs + output, attention_scores = outputs + + # Check the shape of the unpacked outputs + self.assertEqual(output.shape, (2, 8, 16)) # Output shape + self.assertEqual( + attention_scores.shape, (2, 8, 4) + ) # Attention scores shape From 8a66542b1c4a62a5c0eb7a84805528e54234783d Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Mon, 23 Dec 2024 22:46:40 +0000 Subject: [PATCH 4/4] Removed debug print statement --- keras/src/layers/attention/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index 923dcce9a1b6..15ff906e5922 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -140,7 +140,6 @@ def _calculate_scores(self, query, key): else: raise ValueError("scores not computed") - print("scores", scores) return scores def _apply_scores(self, scores, value, scores_mask=None, training=False):
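
The series above is self-contained; as a quick illustration of the behavior the new unit tests pin down, here is a minimal usage sketch. It assumes a Keras 3 install with these patches applied; the shapes and the `numpy`/`keras.layers` imports mirror the test file, and the assertions state expected shapes rather than output captured from a run.

import numpy as np
from keras import layers

# Dummy inputs: batch of 2, query length Tq=8, value length Tv=4, feature dim 16.
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

layer = layers.Attention()

# With the fix, return_attention_scores=True yields an (output, scores) tuple.
output, scores = layer([query, value], return_attention_scores=True)
assert output.shape == (2, 8, 16)  # (batch, Tq, dim): one context vector per query step
assert scores.shape == (2, 8, 4)   # (batch, Tq, Tv): one weight per query/value pair

# Without the flag, the layer still returns a single tensor, as before.
output_only = layer([query, value])
assert output_only.shape == (2, 8, 16)

Two details worth noting from the third patch: `compute_output_shape` now unpacks a three-element `input_shape` (query, value, key, with `None` standing in for an absent key), so callers invoking it directly must pass all three entries; and `compute_output_spec` keys off the `self._return_attention_scores` flag that `call` records, so the two-output spec is produced only once that flag has been set.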