Skip to content

Commit

Permalink
Clean up no longer needed workarorudns (#2261)
Browse files Browse the repository at this point in the history
Co-authored-by: eellison <eellison@fb.com>
  • Loading branch information
eellison and eellison authored May 26, 2020
1 parent 3974cfe commit 5ba57ea
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
19 changes: 6 additions & 13 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,6 @@
import torchvision


# TODO: https://github.com/pytorch/pytorch/issues/26727
def zeros_like(tensor, dtype):
# type: (Tensor, int) -> Tensor
return torch.zeros_like(tensor, dtype=dtype, layout=tensor.layout,
device=tensor.device, pin_memory=tensor.is_pinned())


class BalancedPositiveNegativeSampler(object):
"""
This class samples batches, ensuring that they contain a fixed proportion of positives
Expand Down Expand Up @@ -66,15 +59,15 @@ def __call__(self, matched_idxs):
neg_idx_per_image = negative[perm2]

# create binary mask from indices
pos_idx_per_image_mask = zeros_like(
pos_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
neg_idx_per_image_mask = zeros_like(
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)

pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1, dtype=torch.uint8)
neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1, dtype=torch.uint8)
pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1

pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask)
Expand Down Expand Up @@ -304,8 +297,8 @@ def __call__(self, match_quality_matrix):
between_thresholds = (matched_vals >= self.low_threshold) & (
matched_vals < self.high_threshold
)
matches[below_low_threshold] = torch.tensor(self.BELOW_LOW_THRESHOLD)
matches[between_thresholds] = torch.tensor(self.BETWEEN_THRESHOLDS)
matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
matches[between_thresholds] = self.BETWEEN_THRESHOLDS

if self.allow_low_quality_matches:
assert all_matches is not None
Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
y = (y - offset_y) * scale_y
y = y.floor().long()

x[x_boundary_inds] = torch.tensor(heatmap_size - 1)
y[y_boundary_inds] = torch.tensor(heatmap_size - 1)
x[x_boundary_inds] = heatmap_size - 1
y[y_boundary_inds] = heatmap_size - 1

valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
vis = keypoints[..., 2] > 0
Expand Down Expand Up @@ -584,11 +584,11 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):

# Label background (below the low threshold)
bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_in_image[bg_inds] = torch.tensor(0)
labels_in_image[bg_inds] = 0

# Label ignore proposals (between low and high thresholds)
ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_in_image[ignore_inds] = torch.tensor(-1) # -1 is ignored by sampler
labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler

matched_idxs.append(clamped_matched_idxs_in_image)
labels.append(labels_in_image)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,11 @@ def assign_targets_to_anchors(self, anchors, targets):

# Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = torch.tensor(0.0)
labels_per_image[bg_indices] = 0.0

# discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = torch.tensor(-1.0)
labels_per_image[inds_to_discard] = -1.0

labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image)
Expand Down

0 comments on commit 5ba57ea

Please sign in to comment.