Skip to content

Commit

Permalink
Fix shape inference bug for conv2dtranspose
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 5, 2023
1 parent 309e688 commit 5949a76
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras_core/backend/common/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ def compute_conv_transpose_padding_args_for_torch(
def _get_output_shape_given_tf_padding(
input_size, kernel_size, strides, padding, output_padding, dilation_rate
):
if input_size is None:
return None

assert padding.lower() in {"valid", "same"}

kernel_size = (kernel_size - 1) * dilation_rate + 1
Expand Down
20 changes: 20 additions & 0 deletions keras_core/layers/convolutional/conv_transpose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,3 +835,23 @@ def test_conv1d_transpose_consistency(
# Compare results
kc_res = kc_layer(input)
self.assertAllClose(expected_res, kc_res, atol=1e-5)

@parameterized.product(
kernel_size=list(range(1, 5)),
strides=list(range(1, 5)),
padding=["same", "valid"],
output_padding=[None] + list(range(1, 5)),
)
def test_shape_inference_static_unknown_shape(
self, kernel_size, strides, padding, output_padding
):
x = layers.Input(shape=(None, None, 3))
x = layers.Conv2DTranspose(
filters=2,
kernel_size=kernel_size,
strides=strides,
padding=padding,
output_padding=output_padding,
dilation_rate=1,
)(x)
self.assertEqual(x.shape, (None, None, None, 2))

0 comments on commit 5949a76

Please sign in to comment.