From 60e6981d77ac050dc37828792f0deabca0d69cc1 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 11 Jan 2022 14:22:44 +0800 Subject: [PATCH] [Fix] Fix the unit test of correlation op --- tests/test_ops/test_correlation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_ops/test_correlation.py b/tests/test_ops/test_correlation.py index 6b75a9f38b..6cf5f9f72d 100644 --- a/tests/test_ops/test_correlation.py +++ b/tests/test_ops/test_correlation.py @@ -30,10 +30,13 @@ def _test_correlation(self, dtype=torch.float): out = layer(input1, input2) out.backward(torch.ones_like(out)) - gt_out = torch.tensor(_gt_out, dtype=dtype) - assert_equal_tensor(out.cpu(), gt_out) - assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu()) - assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu()) + # `eq_cpu` is not implemented for 'Half' in torch1.5.0, + # so we need to make a comparison for cuda tensor + # rather than cpu tensor + gt_out = torch.tensor(_gt_out, dtype=dtype).cuda() + assert_equal_tensor(out, gt_out) + assert_equal_tensor(input1.grad.detach(), input2) + assert_equal_tensor(input2.grad.detach(), input1) @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support')