From c82bd84069fca53dcdf758d7f803d2e80b086725 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Aug 2023 10:49:04 +0800 Subject: [PATCH 1/3] change the order of condition, so that torch.fx can trace these modules --- mmcv/cnn/bricks/wrappers.py | 14 +++++++------- tests/test_cnn/test_wrappers.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/mmcv/cnn/bricks/wrappers.py b/mmcv/cnn/bricks/wrappers.py index 07eb04ee32..fc98c35584 100644 --- a/mmcv/cnn/bricks/wrappers.py +++ b/mmcv/cnn/bricks/wrappers.py @@ -41,7 +41,7 @@ def backward(ctx, grad: torch.Tensor) -> tuple: class Conv2d(nn.Conv2d): def forward(self, x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: out_shape = [x.shape[0], self.out_channels] for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, self.padding, self.stride, self.dilation): @@ -62,7 +62,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Conv3d(nn.Conv3d): def forward(self, x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: out_shape = [x.shape[0], self.out_channels] for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, self.padding, self.stride, self.dilation): @@ -84,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ConvTranspose2d(nn.ConvTranspose2d): def forward(self, x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: out_shape = [x.shape[0], self.out_channels] for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, self.padding, self.stride, @@ -106,7 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ConvTranspose3d(nn.ConvTranspose3d): def forward(self, x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: out_shape = [x.shape[0], self.out_channels] for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, self.padding, self.stride, @@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d): def forward(self, x: torch.Tensor) -> torch.Tensor: # PyTorch 1.9 does not support empty tensor inference yet - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0: out_shape = list(x.shape[:2]) for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), _pair(self.padding), _pair(self.stride), @@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d): def forward(self, x: torch.Tensor) -> torch.Tensor: # PyTorch 1.9 does not support empty tensor inference yet - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0: out_shape = list(x.shape[:2]) for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), _triple(self.padding), @@ -164,7 +164,7 @@ class Linear(torch.nn.Linear): def forward(self, x: torch.Tensor) -> torch.Tensor: # empty tensor forward of Linear layer is supported in Pytorch 1.6 - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): + if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0: out_shape = [x.shape[0], self.out_features] empty = NewEmptyTensorOp.apply(x, out_shape) if self.training: diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 02e0f13cd7..43f45c9a48 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -374,3 +374,25 @@ def test_nn_op_forward_called(): wrapper = Linear(3, 3) wrapper(x_normal) nn_module_forward.assert_called_with(x_normal) + + +def test_fx_compatibility(): + try: + from torch import fx + + # ensure the fx trace can pass the network + for Net in (MaxPool2d, MaxPool3d): + net = Net(1) + gm_module = fx.symbolic_trace(net) + print(gm_module.code) + for Net in (Linear, ): + net = Net(1, 1) + gm_module = fx.symbolic_trace(net) + print(gm_module.code) + for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): + net = Net(1, 1, 1) + gm_module = fx.symbolic_trace(net) + print(gm_module.code) + except ImportError: + # torch.fx might not be available + pass From 616fc7260778948c2f6656103c759ae205b30b77 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Aug 2023 12:02:25 +0800 Subject: [PATCH 2/3] guard the test of torch.fx on MaxPool2d, MaxPool3d under torch 1.10 --- tests/test_cnn/test_wrappers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 43f45c9a48..94c51ce15c 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -376,6 +376,7 @@ def test_nn_op_forward_called(): nn_module_forward.assert_called_with(x_normal) +@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10)) def test_fx_compatibility(): try: from torch import fx From fc883ec85f8bbd08fe934d1fa19a22b7051b3c04 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Aug 2023 16:53:03 +0800 Subject: [PATCH 3/3] modify skip test conditions --- tests/test_cnn/test_wrappers.py | 37 +++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 94c51ce15c..8c76ccbdd4 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -4,6 +4,8 @@ import pytest import torch import torch.nn as nn +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, Linear, MaxPool2d, MaxPool3d) @@ -376,24 +378,19 @@ def test_nn_op_forward_called(): nn_module_forward.assert_called_with(x_normal) -@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10)) +@pytest.mark.skipif( + digit_version(TORCH_VERSION) < digit_version('1.10'), + reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9') def test_fx_compatibility(): - try: - from torch import fx - - # ensure the fx trace can pass the network - for Net in (MaxPool2d, MaxPool3d): - net = Net(1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - for Net in (Linear, ): - net = Net(1, 1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): - net = Net(1, 1, 1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - except ImportError: - # torch.fx might not be available - pass + from torch import fx + + # ensure the fx trace can pass the network + for Net in (MaxPool2d, MaxPool3d): + net = Net(1) + gm_module = fx.symbolic_trace(net) # noqa: F841 + for Net in (Linear, ): + net = Net(1, 1) + gm_module = fx.symbolic_trace(net) # noqa: F841 + for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): + net = Net(1, 1, 1) + gm_module = fx.symbolic_trace(net) # noqa: F841