Fix: Return Attention Scores when return_attention_scores=True #20684

Merged · 4 commits · Dec 24, 2024
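For context, a minimal usage sketch (not part of the diff) of the behavior this PR fixes, following the shapes used in the tests added below and assuming Keras 3 with NumPy inputs:

import numpy as np
from keras import layers

# Dummy batches: query is (batch_size, Tq, dim), value is (batch_size, Tv, dim).
query = np.random.random((2, 8, 16)).astype("float32")
value = np.random.random((2, 4, 16)).astype("float32")

layer = layers.Attention()

# With the fix, requesting scores returns an (output, attention_scores) tuple.
output, attention_scores = layer([query, value], return_attention_scores=True)
print(output.shape)            # (2, 8, 16) -> (batch_size, Tq, dim)
print(attention_scores.shape)  # (2, 8, 4)  -> (batch_size, Tq, Tv)

# Without the flag, only the attention output is returned.
output = layer([query, value])
print(output.shape)            # (2, 8, 16)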
58 changes: 52 additions & 6 deletions keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.layers.layer import Layer


@@ -84,6 +85,8 @@ def __init__(
f"Received: score_mode={score_mode}"
)

self._return_attention_scores = False

def build(self, input_shape):
self._validate_inputs(input_shape)
self.scale = None
@@ -217,6 +220,7 @@ def call(
use_causal_mask=False,
):
self._validate_inputs(inputs=inputs, mask=mask)
self._return_attention_scores = return_attention_scores
q = inputs[0]
v = inputs[1]
k = inputs[2] if len(inputs) > 2 else v
@@ -226,16 +230,17 @@
scores_mask = self._calculate_score_mask(
scores, v_mask, use_causal_mask
)
result, attention_scores = self._apply_scores(
attention_output, attention_scores = self._apply_scores(
scores=scores, value=v, scores_mask=scores_mask, training=training
)
if q_mask is not None:
# Mask of shape [batch_size, Tq, 1].
q_mask = ops.expand_dims(q_mask, axis=-1)
result *= ops.cast(q_mask, dtype=result.dtype)
attention_output *= ops.cast(q_mask, dtype=attention_output.dtype)
if return_attention_scores:
return result, attention_scores
return result
return (attention_output, attention_scores)
else:
return attention_output

def compute_mask(self, inputs, mask=None):
self._validate_inputs(inputs=inputs, mask=mask)
@@ -244,8 +249,49 @@ def compute_mask(self, inputs, mask=None):
return ops.convert_to_tensor(mask[0])

def compute_output_shape(self, input_shape):
"""Returns shape of value tensor dim, but for query tensor length"""
return (*input_shape[0][:-1], input_shape[1][-1])
query_shape, value_shape, key_shape = input_shape
if key_shape is None:
key_shape = value_shape

output_shape = (*query_shape[:-1], value_shape[-1])
if self._return_attention_scores:
scores_shape = (query_shape[0], query_shape[1], key_shape[1])
return output_shape, scores_shape
return output_shape

def compute_output_spec(
self,
inputs,
mask=None,
return_attention_scores=False,
training=None,
use_causal_mask=False,
):
# Validate and unpack inputs
self._validate_inputs(inputs, mask)
query = inputs[0]
value = inputs[1]
key = inputs[2] if len(inputs) > 2 else value

# Compute primary output shape
output_shape = self.compute_output_shape(
[query.shape, value.shape, key.shape]
)
output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)

# Handle attention scores if requested
if self._return_attention_scores:
scores_shape = (
query.shape[0],
query.shape[1],
key.shape[1],
) # (batch_size, Tq, Tv)
attention_scores_spec = KerasTensor(
scores_shape, dtype=self.compute_dtype
)
return (output_spec, attention_scores_spec)

return output_spec

def _validate_inputs(self, inputs, mask=None):
"""Validates arguments of the call method."""
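A short sketch (not from the PR) of the shape contract that the new compute_output_shape above implements, assuming it is called directly with (query, value, key) shape tuples and that a None key shape falls back to the value shape:

from keras import layers

layer = layers.Attention()

# The output keeps the query's leading dims and takes the value's feature dim:
# (*query_shape[:-1], value_shape[-1]).
print(layer.compute_output_shape([(2, 8, 16), (2, 4, 16), None]))  # (2, 8, 16)

# If attention scores were requested on the preceding call, a second shape
# (batch_size, Tq, Tv), here (2, 8, 4), is returned alongside the output shape.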
59 changes: 59 additions & 0 deletions keras/src/layers/attention/attention_test.py
@@ -358,3 +358,62 @@ def test_attention_compute_output_shape(self):
),
output.shape,
)

def test_return_attention_scores_true(self):
"""Test that the layer returns attention scores along with outputs."""
# Generate dummy input data
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

# Initialize the Attention layer
layer = layers.Attention()

# Call the layer with return_attention_scores=True
output, attention_scores = layer(
[query, value], return_attention_scores=True
)

# Check the shape of the outputs
self.assertEqual(output.shape, (2, 8, 16)) # Output shape
self.assertEqual(
attention_scores.shape, (2, 8, 4)
) # Attention scores shape

def test_return_attention_scores_true_and_tuple(self):
"""Test that the layer outputs are a tuple when
return_attention_scores=True."""
# Generate dummy input data
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

# Initialize the Attention layer
layer = layers.Attention()

# Call the layer with return_attention_scores=True
outputs = layer([query, value], return_attention_scores=True)

# Check that outputs is a tuple
self.assertIsInstance(
outputs, tuple, "Expected the outputs to be a tuple"
)

def test_return_attention_scores_true_tuple_then_unpack(self):
"""Test that outputs can be unpacked correctly."""
# Generate dummy input data
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

# Initialize the Attention layer
layer = layers.Attention()

# Call the layer with return_attention_scores=True
outputs = layer([query, value], return_attention_scores=True)

# Unpack the outputs
output, attention_scores = outputs

# Check the shape of the unpacked outputs
self.assertEqual(output.shape, (2, 8, 16)) # Output shape
self.assertEqual(
attention_scores.shape, (2, 8, 4)
) # Attention scores shape