From 66f2922028754c6a7af648e8d39c62a7d7a49659 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 13 Jan 2020 18:31:33 +0100
Subject: [PATCH] Fix for AnchorGenerator when device switch happen (#1745)

* Fix AnchorGenerator if moving from one device to another

* Fixes for the test
---
 test/test_models.py                 | 22 ++++++++++++++++++++++
 torchvision/models/detection/rpn.py |  7 ++++++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/test/test_models.py b/test/test_models.py
index 14c70175dc0..1c0b4892209 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -226,6 +226,28 @@ def test_fasterrcnn_double(self):
         self.assertTrue("scores" in out[0])
         self.assertTrue("labels" in out[0])
 
+    @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
+    def test_fasterrcnn_switch_devices(self):
+        model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
+        model.cuda()
+        model.eval()
+        input_shape = (3, 300, 300)
+        x = torch.rand(input_shape, device='cuda')
+        model_input = [x]
+        out = model(model_input)
+        self.assertIs(model_input[0], x)
+        self.assertEqual(len(out), 1)
+        self.assertTrue("boxes" in out[0])
+        self.assertTrue("scores" in out[0])
+        self.assertTrue("labels" in out[0])
+        # now switch to cpu and make sure it works
+        model.cpu()
+        x = x.cpu()
+        out_cpu = model([x])
+        self.assertTrue("boxes" in out_cpu[0])
+        self.assertTrue("scores" in out_cpu[0])
+        self.assertTrue("labels" in out_cpu[0])
+
 
 for model_name in get_available_classification_models():
     # for-loop bodies don't define scopes, so we have to save the variables
diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py
index f9f8cadefc3..f1c720bf748 100644
--- a/torchvision/models/detection/rpn.py
+++ b/torchvision/models/detection/rpn.py
@@ -90,7 +90,12 @@ def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="c
     def set_cell_anchors(self, dtype, device):
         # type: (int, Device) -> None  # noqa: F821
         if self.cell_anchors is not None:
-            return
+            cell_anchors = self.cell_anchors
+            assert cell_anchors is not None
+            # suppose that all anchors have the same device
+            # which is a valid assumption in the current state of the codebase
+            if cell_anchors[0].device == device:
+                return
 
         cell_anchors = [
             self.generate_anchors(