diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index f34db4ce970..c6fe8856a01 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -386,7 +386,8 @@ def compute_loss(self, targets, head_outputs, anchors): matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64)) + matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, + device=anchors_per_image.device)) continue match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)