Remove output_shape property in MHA (#20543)
* Simplify `output_shape` logic in MHA and remove `output_shape` property.

* Fix CI

* Update test

* Update test
james77777778 authored Nov 25, 2024
1 parent 0078f24 commit bef0a9e
Showing 4 changed files with 58 additions and 22 deletions.
40 changes: 19 additions & 21 deletions keras/src/layers/attention/multi_head_attention.py
@@ -127,6 +127,17 @@ def __init__(
         self._value_dim = value_dim if value_dim else key_dim
         self._dropout = dropout
         self._use_bias = use_bias
+        if output_shape:
+            if isinstance(output_shape, int):
+                output_shape = (output_shape,)
+            try:
+                output_shape = tuple(output_shape)
+            except:
+                raise ValueError(
+                    f"Invalid `output_shape`: {output_shape}. When "
+                    "specified, the `output_shape` should be of type tuple, "
+                    "list, or int."
+                )
         self._output_shape = output_shape
         self._flash_attention = flash_attention or is_flash_attention_enabled()
         self._kernel_initializer = initializers.get(kernel_initializer)
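
For illustration, a minimal sketch of what this normalization does (the `num_heads`/`key_dim` values below are arbitrary):

    from keras import layers

    # An int is normalized to a one-element tuple: (8,)
    layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)

    # A list is normalized to a tuple: (4, 8)
    layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=[4, 8])

    # Anything that is neither an int nor convertible to a tuple now fails
    # fast at construction time rather than at build time:
    layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8.0)
    # ValueError: Invalid `output_shape`: 8.0. When specified, the
    # `output_shape` should be of type tuple, list, or int.
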
@@ -176,9 +187,8 @@ def dropout(self):
     def use_bias(self):
         return self._use_bias
 
-    @property
-    def output_shape(self):
-        return self._output_shape
+    # Avoid exposing `output_shape` as it may conflict with `Functional` and
+    # `Sequential` models when calling `summary()`.
 
     @property
     def attention_axes(self):
@@ -343,14 +353,7 @@ def _make_output_dense(self, query_shape, common_kwargs, name=None):
         """
         query_rank = len(query_shape)
         if self._output_shape:
-            if isinstance(self._output_shape, (tuple, list)):
-                output_shape = self._output_shape
-            elif isinstance(self._output_shape, int):
-                output_shape = [self._output_shape]
-            else:
-                raise ValueError(
-                    f"Invalid output_shape type: {self._output_shape}"
-                )
+            output_shape = self._output_shape
         else:
             output_shape = [query_shape[-1]]
         einsum_equation, bias_axes, output_rank = _build_proj_equation(
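
Since `__init__` now guarantees a tuple, the output projection can feed `self._output_shape` straight into `_build_proj_equation`, including multi-dimensional shapes. A quick sketch (the shapes here are illustrative):

    import numpy as np
    from keras import layers

    mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=(4, 8))
    x = np.ones((2, 4, 8))
    # The output dense maps the attention result to
    # query_shape[:-1] + output_shape.
    print(mha(x, x).shape)  # (2, 4, 4, 8)
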
@@ -664,26 +667,21 @@ def compute_output_shape(
         value_shape,
         key_shape=None,
     ):
+        query_shape = tuple(query_shape)
+        value_shape = tuple(value_shape)
         if key_shape is None:
             key_shape = value_shape
+        else:
+            key_shape = tuple(key_shape)
 
         if value_shape[1:-1] != key_shape[1:-1]:
             raise ValueError(
                 "All dimensions of `value` and `key`, except the last one, "
                 f"must be equal. Received: value_shape={value_shape} and "
                 f"key_shape={key_shape}"
             )
-
         if self._output_shape:
-            if isinstance(self._output_shape, (tuple, list)):
-                return query_shape[:-1] + tuple(self._output_shape)
-            elif isinstance(self._output_shape, int):
-                return query_shape[:-1] + (self._output_shape,)
-            else:
-                raise ValueError(
-                    f"Invalid output_shape type: {self._output_shape}"
-                )
-
+            query_shape = query_shape[:-1] + self._output_shape
         return query_shape
 
     def compute_output_spec(
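
With both the arguments and `self._output_shape` normalized to tuples, the shape arithmetic becomes a plain tuple concatenation, and list-valued shapes now work too. A quick check:

    from keras import layers

    mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=16)
    # List inputs are fine; the result is query_shape[:-1] + (16,).
    print(mha.compute_output_shape([2, 4, 8], [2, 2, 8]))  # (2, 4, 16)
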
12 changes: 12 additions & 0 deletions keras/src/layers/attention/multi_head_attention_test.py
@@ -247,6 +247,14 @@ def test_compute_output_shape(
         )
         self.assertEqual(output.shape, comp_output_shape)
 
+        # Test shapes as lists.
+        comp_output_shape = layer.compute_output_shape(
+            list(query_shape),
+            list(value_shape),
+            list(key_shape) if key_shape is not None else None,
+        )
+        self.assertEqual(output.shape, comp_output_shape)
+
     @parameterized.named_parameters(
         ("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), (2,)),
         ("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)),
@@ -634,3 +642,7 @@ def test_multi_head_attention_output_shape_as_tuple(self):
         assert output.shape == (2, 4, 8, 8), (
             f"Expected shape (2, 4, 8, 8)," f" got {output.shape}"
         )
+
+    def test_multi_head_attention_output_shape_error(self):
+        with self.assertRaisesRegex(ValueError, r"Invalid `output_shape`"):
+            layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8.0)
2 changes: 1 addition & 1 deletion keras/src/utils/summary_utils.py
@@ -103,7 +103,7 @@ def format_shape(shape):
         else:
             try:
                 if hasattr(layer, "output_shape"):
-                    output_shapes = layer.output_shape
+                    output_shapes = format_shape(layer.output_shape)
                 else:
                     outputs = layer.compute_output_shape(**layer._build_shapes_dict)
                     output_shapes = tree.map_shape_structure(
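
The `hasattr` check above is why removing the property matters: when a layer advertises `output_shape`, `print_summary` trusts it and never calls `compute_output_shape`, so the old MHA property surfaced its raw constructor argument (for example `(4,)`) in the summary table. A small check of the new behavior, assuming the base `Layer` defines no such attribute (which the fallback branch relies on):

    from keras import layers

    mha = layers.MultiHeadAttention(num_heads=2, key_dim=2, output_shape=(4,))
    # With the property gone, `print_summary` takes the
    # `compute_output_shape` fallback for this layer.
    print(hasattr(mha, "output_shape"))                    # False
    print(mha.compute_output_shape((1, 2, 2), (1, 2, 2)))  # (1, 2, 4)
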
26 changes: 26 additions & 0 deletions keras/src/utils/summary_utils_test.py
@@ -98,3 +98,29 @@ def print_to_variable(text, line_break=False):
         self.assertIn("Total params: 12", summary_content)
         self.assertIn("Trainable params: 12", summary_content)
         self.assertIn("Non-trainable params: 0", summary_content)
+
+    def test_print_model_summary_with_mha(self):
+        # In Keras <= 3.6, MHA exposes an `output_shape` property, which
+        # breaks this test.
+        class MyModel(models.Model):
+            def __init__(self):
+                super().__init__()
+                self.mha = layers.MultiHeadAttention(2, 2, output_shape=(4,))
+
+            def call(self, inputs):
+                return self.mha(inputs, inputs, inputs)
+
+        model = MyModel()
+        model(np.ones((1, 2, 2)))
+
+        summary_content = []
+
+        def print_to_variable(text, line_break=False):
+            summary_content.append(text)
+
+        summary_utils.print_summary(model, print_fn=print_to_variable)
+        summary_content = "\n".join(summary_content)
+        self.assertIn("(1, 2, 4)", summary_content)  # mha
+        self.assertIn("Total params: 56", summary_content)
+        self.assertIn("Trainable params: 56", summary_content)
+        self.assertIn("Non-trainable params: 0", summary_content)
