diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 90831b3d4bb..d67d5856f76 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -40,7 +40,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): sampled_pos_inds_subset = torch.where(labels > 0)[0] labels_pos = labels[sampled_pos_inds_subset] N, num_classes = class_logits.shape - box_regression = box_regression.reshape(N, -1, 4) + box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) box_loss = det_utils.smooth_l1_loss( box_regression[sampled_pos_inds_subset, labels_pos],