From 1706921b7998ec9df35969f55d38c6f4badcb1f4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 19 Nov 2020 17:12:04 +0000 Subject: [PATCH] Change reshape to support empty batches. (#3031) --- torchvision/models/detection/roi_heads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 90831b3d4bb..d67d5856f76 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -40,7 +40,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): sampled_pos_inds_subset = torch.where(labels > 0)[0] labels_pos = labels[sampled_pos_inds_subset] N, num_classes = class_logits.shape - box_regression = box_regression.reshape(N, -1, 4) + box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) box_loss = det_utils.smooth_l1_loss( box_regression[sampled_pos_inds_subset, labels_pos],