Fix MHA layer return_attn_scores symbolic call.
fchollet committed Apr 12, 2024
1 parent cf8ca3c commit dca1d8a
Showing 3 changed files with 48 additions and 0 deletions.
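For context, the call pattern this commit fixes is building MultiHeadAttention outputs from symbolic (functional-API) inputs with return_attention_scores=True. Below is a minimal sketch modeled on the new test added in this commit; the concrete shapes are illustrative, not part of the commit:

import numpy as np
from keras import layers

mha = layers.MultiHeadAttention(num_heads=4, key_dim=2)

# Symbolic call on Input tensors: after this commit it returns a pair of
# KerasTensors (the attention output and the attention scores).
x = layers.Input(batch_shape=(2, 3, 5))
y = layers.Input(batch_shape=(2, 3, 5))
out, scores = mha(x, y, return_attention_scores=True)
# out.shape    -> (2, 3, 5)     same as the query shape
# scores.shape -> (2, 4, 3, 3)  (batch, num_heads, query_len, query_len)

# The eager call with concrete arrays yields the same shapes.
out, scores = mha(
    np.random.random((2, 3, 5)),
    np.random.random((2, 3, 5)),
    return_attention_scores=True,
)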
1 change: 1 addition & 0 deletions keras/callbacks/backup_and_restore_test.py
@@ -44,6 +44,7 @@ class BackupAndRestoreCallbackTest(testing.TestCase):
    def make_model(self):
        model = Sequential(
            [
                layers.Input((3,)),
                CanaryLayer(),
                layers.Dense(1),
            ]
32 changes: 32 additions & 0 deletions keras/layers/attention/multi_head_attention.py
@@ -4,6 +4,7 @@

import numpy as np

from keras import backend
from keras import constraints
from keras import initializers
from keras import ops
@@ -603,6 +604,37 @@ def compute_output_shape(

        return query_shape

    def compute_output_spec(
        self,
        query,
        value,
        key=None,
        query_mask=None,
        value_mask=None,
        key_mask=None,
        attention_mask=None,
        return_attention_scores=False,
        training=None,
        use_causal_mask=False,
    ):
        if key is not None:
            key_shape = key.shape
        else:
            key_shape = None
        output_shape = self.compute_output_shape(
            query.shape, value.shape, key_shape
        )
        output_spec = backend.KerasTensor(
            output_shape, dtype=self.compute_dtype
        )
        if return_attention_scores:
            length = query.shape[1]
            attention_shape = (query.shape[0], self.num_heads, length, length)
            return output_spec, backend.KerasTensor(
                attention_shape, dtype=self.compute_dtype
            )
        return output_spec


def _index_to_einsum_variable(i):
"""Coverts an index to a einsum variable name.
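The new compute_output_spec reuses compute_output_shape for the primary output and, when return_attention_scores=True, adds a second spec of shape (batch, num_heads, query_len, query_len). A rough sketch of that shape bookkeeping, assuming keras.KerasTensor is used to build the symbolic inputs (the shapes and the direct method call are illustrative, not part of the commit):

import keras
from keras import layers

mha = layers.MultiHeadAttention(num_heads=4, key_dim=2)
query = keras.KerasTensor((2, 3, 5))  # (batch, query_len, dim)
value = keras.KerasTensor((2, 3, 5))  # (batch, value_len, dim)

out_spec, scores_spec = mha.compute_output_spec(
    query, value, return_attention_scores=True
)
# out_spec.shape    == (2, 3, 5)     compute_output_shape returns the query shape
# scores_spec.shape == (2, 4, 3, 3)  (batch, num_heads, query_len, query_len)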
15 changes: 15 additions & 0 deletions keras/layers/attention/multi_head_attention_test.py
@@ -363,3 +363,18 @@ def test_lora(self):
        new_model.save_weights(temp_filepath)
        model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

    @parameterized.parameters([((1, 2, 3),), ((2, 3, 5),)])
    def test_symbolic_return_attention_scores(self, shape):
        mha = layers.MultiHeadAttention(num_heads=4, key_dim=2)
        x = layers.Input(batch_shape=shape)
        y = layers.Input(batch_shape=shape)
        symbolic_out = mha(x, y, return_attention_scores=True)
        self.assertLen(symbolic_out, 2)

        x = np.random.random(shape)
        y = np.random.random(shape)
        out = mha(x, y, return_attention_scores=True)
        self.assertLen(out, 2)
        self.assertEqual(symbolic_out[0].shape, out[0].shape)
        self.assertEqual(symbolic_out[1].shape, out[1].shape)
