From c1424671448f30dc94cc8d3fc5e2ef00501c8b2d Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 18 Sep 2024 17:54:48 +0000 Subject: [PATCH] fix (frontends)(torch)(convolutional_functions): adding `filter_format` argument to `_conv_transpose` to avoid transposing the weights if they're in the channel_last format --- .../frontends/torch/nn/functional/convolution_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/torch/nn/functional/convolution_functions.py b/ivy/functional/frontends/torch/nn/functional/convolution_functions.py index 89b4bcf3c7f67..12707287a3f64 100644 --- a/ivy/functional/frontends/torch/nn/functional/convolution_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/convolution_functions.py @@ -46,9 +46,11 @@ def _conv_transpose( output_padding=0, groups=1, dilation=1, + filter_format="channel_first", ): dims = len(input.shape) - 2 - weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1)) + if filter_format == "channel_first": + weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1)) for i in range(dims): weight = ivy.flip(weight, axis=i) padding, output_padding, stride, dilation = map( @@ -185,6 +187,7 @@ def conv_transpose1d( output_padding=output_padding, groups=groups, dilation=dilation, + filter_format="channel_first", )