Skip to content

Commit

Permalink
Commented AnchorGenerator (#1941)
Browse files Browse the repository at this point in the history
  • Loading branch information
SohamTamba authored Mar 10, 2020
1 parent 7d1cd1d commit 2611f5c
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(
self._cache = {}

# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
Expand Down Expand Up @@ -111,6 +113,8 @@ def set_cell_anchors(self, dtype, device):
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[int]])
anchors = []
Expand All @@ -127,6 +131,8 @@ def grid_anchors(self, grid_sizes, strides):
stride_width = torch.tensor(stride_width, dtype=torch.float32)
stride_height = torch.tensor(stride_height, dtype=torch.float32)
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(
0, grid_width, dtype=torch.float32, device=device
) * stride_width
Expand All @@ -138,6 +144,8 @@ def grid_anchors(self, grid_sizes, strides):
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)
Expand All @@ -158,6 +166,7 @@ def forward(self, image_list, feature_maps):
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]

dtype, device = feature_maps[0].dtype, feature_maps[0].device
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
Expand Down

0 comments on commit 2611f5c

Please sign in to comment.