Add inductor support for conv3d transpose (#129458)
This PR adds Conv3d Transpose support in inductor. It basically reuses and extends the existing Conv2d Transpose lowering and unit tests to cover Conv3d Transpose.

Pull Request resolved: #129458
Approved by: https://github.com/jgong5, https://github.com/jansel
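
As a rough usage sketch (not part of this commit), the change lets a 3D transposed convolution go through the inductor CPU path. The snippet below mirrors the new unit test; the module sizes, the relu epilogue, and the channels-last-3d input are illustrative and assume a CPU build with oneDNN support.

import torch

# Mirrors the shape of the new test: ConvTranspose3d followed by a unary op.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose3d(3, 16, 3, stride=2, padding=1)

    def forward(self, x):
        return torch.relu(self.conv_transpose(x))

mod = M().eval()
# 5D NCDHW input in channels-last-3d memory format, as exercised by the test.
x = torch.randn(1, 3, 17, 28, 28).to(memory_format=torch.channels_last_3d)
with torch.no_grad():
    out = torch.compile(mod)(x)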
yanbing-j authored and pytorchmergebot committed Jun 27, 2024
1 parent 9b5b93c commit 5ee893a
Showing 2 changed files with 47 additions and 18 deletions.
25 changes: 17 additions & 8 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -319,20 +319,28 @@ static ideep::tensor::desc get_conv_transpose_expected_weights_desc(
   }
 }
 
-static Tensor mkldnn_reorder_conv_transpose2d_weight(
+static Tensor mkldnn_reorder_conv_transpose_weight(
     const Tensor& self,
     IntArrayRef padding,
     IntArrayRef output_padding,
     IntArrayRef stride,
     IntArrayRef dilation,
     int64_t groups,
     c10::OptionalArrayRef<int64_t> input_size) {
+  TORCH_CHECK(
+      (self.dim() == 4 || self.dim() == 5),
+      "mkldnn_reorder_conv_transpose_weight only supports conv_transpose2d and conv_transpose3d");
   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
-  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_conv_transpose2d_weight");
-  const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
-  const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
-  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);
-  const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", 2);
+  mkldnn_check_low_precision(
+      self.scalar_type(), "mkldnn_reorder_conv_transpose_weight");
+  int64_t pdim = self.dim() - 2;
+  const auto padding_expanded =
+      expand_param_if_needed(padding, "padding", pdim);
+  const auto stride_expanded = expand_param_if_needed(stride, "stride", pdim);
+  const auto dilation_expanded =
+      expand_param_if_needed(dilation, "dilation", pdim);
+  const auto output_padding_expanded =
+      expand_param_if_needed(output_padding, "output_padding", pdim);
 
   ideep::dims src_dims = ideep::dims();
   bool is_channels_last = false;
@@ -341,7 +349,8 @@ static Tensor mkldnn_reorder_conv_transpose2d_weight(
     src_dims = input_size.value().vec();
     // if has input size, we always use channels last.
     is_channels_last = true;
-    memory_format = at::MemoryFormat::ChannelsLast;
+    memory_format = self.dim() == 4 ? at::MemoryFormat::ChannelsLast
+                                    : at::MemoryFormat::ChannelsLast3d;
   }
 
   auto self_ = self.contiguous(memory_format);
@@ -532,7 +541,7 @@ static Tensor get_mkldnn_serialized_md(const Tensor& self) {
 TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
-      TORCH_FN(mkldnn_reorder_conv_transpose2d_weight));
+      TORCH_FN(mkldnn_reorder_conv_transpose_weight));
   m.impl(
       TORCH_SELECTIVE_NAME("mkldnn::_reorder_linear_weight"),
       TORCH_FN(mkldnn_reorder_linear_weight));
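
A small illustrative analogue (not the actual ATen helper) of the dimension-generic parameter handling introduced above: the number of spatial dimensions is derived from the weight as self.dim() - 2, and a conv parameter given as a single value is expanded to that many entries, so the same reorder path serves both conv_transpose2d (dim == 4) and conv_transpose3d (dim == 5).

# Hypothetical Python analogue of expand_param_if_needed, for illustration only;
# the real helper is C++ code inside ATen.
def expand_param(param, name, pdim):
    vals = list(param) if hasattr(param, "__len__") else [param]
    if len(vals) == 1:
        return vals * pdim  # broadcast a single value to every spatial dim
    assert len(vals) == pdim, f"{name} expects 1 or {pdim} values, got {len(vals)}"
    return vals

weight_dim = 5                       # 4 for conv_transpose2d weights, 5 for conv_transpose3d
pdim = weight_dim - 2                # spatial dims: 2 for 2d, 3 for 3d
print(expand_param(2, "stride", pdim))           # [2, 2, 2]
print(expand_param([1, 1, 0], "padding", pdim))  # [1, 1, 0]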
40 changes: 30 additions & 10 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -455,24 +455,28 @@ def forward(self, x):
         # 1 kernel for "to_lowp", 2 kernels for unary ops
         self.assertEqual(metrics.generated_kernel_count, 3)
 
-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    @skipIfRocm
-    def test_conv_transpose2d_unary(self):
+    def _test_conv_transpose_unary_base(self, dim=4):
+        assert dim == 4 or dim == 5
+
         class M(torch.nn.Module):
             def __init__(
                 self,
                 unary_fn,
                 **kwargs,
             ):
                 super().__init__()
-                self.conv_transpose2d = torch.nn.ConvTranspose2d(
-                    3, 16, 3, stride=2, padding=1
-                )
+                if dim == 4:
+                    self.conv_transpose = torch.nn.ConvTranspose2d(
+                        3, 16, 3, stride=2, padding=1
+                    )
+                else:
+                    self.conv_transpose = torch.nn.ConvTranspose3d(
+                        3, 16, 3, stride=2, padding=1
+                    )
                 self.unary_fn = unary_fn
 
             def forward(self, x):
-                x = self.conv_transpose2d(x)
+                x = self.conv_transpose(x)
                 return self.unary_fn(x)
 
         dtypes = [
@@ -483,15 +487,19 @@ def forward(self, x):
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
 
+        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
         options = itertools.product(
            unary_list,
-            [torch.contiguous_format, torch.channels_last],
+            [torch.contiguous_format, cl_format],
            dtypes,
        )
 
         for unary_fn, memory_format, dtype in options:
             metrics.reset()
-            x_shape = (1, 3, 28, 28)
+            if dim == 4:
+                x_shape = (1, 3, 28, 28)
+            else:
+                x_shape = (1, 3, 17, 28, 28)
             mod = M(unary_fn).eval()
 
             v = torch.randn(x_shape, dtype=torch.float32).to(
@@ -509,6 +517,18 @@ def forward(self, x):
             generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
             self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv_transpose2d_unary_cpu(self):
+        self._test_conv_transpose_unary_base(dim=4)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv_transpose3d_unary_cpu(self):
+        self._test_conv_transpose_unary_base(dim=5)
+
     def _test_conv_binary_base(self, dim=4):
         assert dim == 4 or dim == 5
 
