From 9d08346d1296dfac4030d88cb17d4de9376abd98 Mon Sep 17 00:00:00 2001 From: Ailing Date: Fri, 13 Mar 2020 09:33:06 -0700 Subject: [PATCH] Add device to torch.tensor. (#1979) --- torchvision/models/detection/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index dac511452bc..8832b3a534c 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -220,10 +220,10 @@ def decode_single(self, rel_codes, boxes): pred_w = torch.exp(dw) * widths[:, None] pred_h = torch.exp(dh) * heights[:, None] - pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w - pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h - pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w - pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h + pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w + pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h + pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w + pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) return pred_boxes