From 2611f5cc79230f3729393279093775d08bc52718 Mon Sep 17 00:00:00 2001 From: Soham Tamba Date: Tue, 10 Mar 2020 09:01:31 -0400 Subject: [PATCH] Commented AnchorGenerator (#1941) --- torchvision/models/detection/rpn.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index f1c720bf748..093db4f6fe6 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -74,6 +74,8 @@ def __init__( self._cache = {} # TODO: https://github.com/pytorch/pytorch/issues/26792 + # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. + # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): # type: (List[int], List[float], int, Device) # noqa: F821 scales = torch.as_tensor(scales, dtype=dtype, device=device) @@ -111,6 +113,8 @@ def set_cell_anchors(self, dtype, device): def num_anchors_per_location(self): return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] + # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), + # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. def grid_anchors(self, grid_sizes, strides): # type: (List[List[int]], List[List[int]]) anchors = [] @@ -127,6 +131,8 @@ def grid_anchors(self, grid_sizes, strides): stride_width = torch.tensor(stride_width, dtype=torch.float32) stride_height = torch.tensor(stride_height, dtype=torch.float32) device = base_anchors.device + + # For output anchor, compute [x_center, y_center, x_center, y_center] shifts_x = torch.arange( 0, grid_width, dtype=torch.float32, device=device ) * stride_width @@ -138,6 +144,8 @@ def grid_anchors(self, grid_sizes, strides): shift_y = shift_y.reshape(-1) shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + # For every (base anchor, output anchor) pair, + # offset each zero-centered base anchor by the center of the output anchor. anchors.append( (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) ) @@ -158,6 +166,7 @@ def forward(self, image_list, feature_maps): grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) image_size = image_list.tensors.shape[-2:] strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes] + dtype, device = feature_maps[0].dtype, feature_maps[0].device self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)