Skip to content

Commit

Permalink
fix (frontends)(torch)(convolutional_functions): adding `filter_forma…
Browse files Browse the repository at this point in the history
…t` argument to `_conv_transpose` to avoid transposing the weights if they're in the channel_last format
  • Loading branch information
YushaArif99 committed Sep 18, 2024
1 parent 55095e7 commit c142467
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def _conv_transpose(
output_padding=0,
groups=1,
dilation=1,
filter_format="channel_first",
):
dims = len(input.shape) - 2
weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1))
if filter_format == "channel_first":
weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1))
for i in range(dims):
weight = ivy.flip(weight, axis=i)
padding, output_padding, stride, dilation = map(
Expand Down Expand Up @@ -185,6 +187,7 @@ def conv_transpose1d(
output_padding=output_padding,
groups=groups,
dilation=dilation,
filter_format="channel_first",
)


Expand Down

0 comments on commit c142467

Please sign in to comment.