Skip to content

Commit

Permalink
add 3d
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin committed Jun 24, 2021
1 parent 6b95a32 commit 6ddfa29
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tests/test_cnn/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_linear(in_w, in_h, in_feature, out_feature):
wrapper(x_empty)


@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 9))
@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
def test_nn_op_forward_called():

for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
Expand All @@ -347,6 +347,20 @@ def test_nn_op_forward_called():
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)

for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3, 10, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)

# non-randn input
x_normal = torch.randn(1, 3, 10, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)

with patch('torch.nn.Linear.forward') as nn_module_forward:
# randn input
x_empty = torch.randn(0, 3)
Expand Down

0 comments on commit 6ddfa29

Please sign in to comment.