From a023ef65c32150f4ea713ab7ec5b6846ebfd69b4 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 23 Jun 2021 20:59:23 +0800 Subject: [PATCH 1/3] empty tensor inference backward continity --- mmcv/cnn/bricks/wrappers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/cnn/bricks/wrappers.py b/mmcv/cnn/bricks/wrappers.py index a464f86dc1..6e125b41ca 100644 --- a/mmcv/cnn/bricks/wrappers.py +++ b/mmcv/cnn/bricks/wrappers.py @@ -128,8 +128,8 @@ def forward(self, x): class MaxPool2d(nn.MaxPool2d): def forward(self, x): - # PyTorch 1.7 does not support empty tensor inference yet - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): out_shape = list(x.shape[:2]) for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), _pair(self.padding), _pair(self.stride), @@ -146,8 +146,8 @@ def forward(self, x): class MaxPool3d(nn.MaxPool3d): def forward(self, x): - # PyTorch 1.7 does not support empty tensor inference yet - if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): out_shape = list(x.shape[:2]) for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), _triple(self.padding), From 6b95a3278d5a0ab0e5d21fca5c7c30a4ecced5f1 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Thu, 24 Jun 2021 10:28:59 +0800 Subject: [PATCH 2/3] update --- tests/test_cnn/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 326cfd2d2a..c6ed7fb325 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature): wrapper(x_empty) -@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 8)) +@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 9)) def test_nn_op_forward_called(): for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']: From 6ddfa2925e6b4a2abb75f62955b08fdb6acfd4da Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Thu, 24 Jun 2021 15:26:45 +0800 Subject: [PATCH 3/3] add 3d --- tests/test_cnn/test_wrappers.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index c6ed7fb325..ffc933fec2 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature): wrapper(x_empty) -@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 9)) +@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10)) def test_nn_op_forward_called(): for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']: @@ -347,6 +347,20 @@ def test_nn_op_forward_called(): wrapper(x_normal) nn_module_forward.assert_called_with(x_normal) + for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']: + with patch(f'torch.nn.{m}.forward') as nn_module_forward: + # randn input + x_empty = torch.randn(0, 3, 10, 10, 10) + wrapper = eval(m)(3, 2, 1) + wrapper(x_empty) + nn_module_forward.assert_called_with(x_empty) + + # non-randn input + x_normal = torch.randn(1, 3, 10, 10, 10) + wrapper = eval(m)(3, 2, 1) + wrapper(x_normal) + nn_module_forward.assert_called_with(x_normal) + with patch('torch.nn.Linear.forward') as nn_module_forward: # randn input x_empty = torch.randn(0, 3)