From 07fb8ba7fad7b5b458ff862919825df4e6f60b52 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 8 Apr 2021 14:58:21 +0000 Subject: [PATCH] Add missing device info. (#3651) --- torchvision/models/detection/retinanet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index f34db4ce970..c6fe8856a01 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -386,7 +386,8 @@ def compute_loss(self, targets, head_outputs, anchors): matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64)) + matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, + device=anchors_per_image.device)) continue match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)