diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index ee9f2a1ab5a..9a63df1c943 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -1,7 +1,7 @@ import torch from torch import nn, Tensor -from typing import List, Optional +from typing import List from .image_list import ImageList @@ -27,7 +27,7 @@ class AnchorGenerator(nn.Module): """ __annotations__ = { - "cell_anchors": Optional[List[torch.Tensor]], + "cell_anchors": List[torch.Tensor], } def __init__( @@ -47,7 +47,8 @@ def __init__( self.sizes = sizes self.aspect_ratios = aspect_ratios - self.cell_anchors = None + self.cell_anchors = [self.generate_anchors(size, aspect_ratio) + for size, aspect_ratio in zip(sizes, aspect_ratios)] # TODO: https://github.com/pytorch/pytorch/issues/26792 # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. @@ -67,24 +68,8 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: return base_anchors.round() def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): - if self.cell_anchors is not None: - cell_anchors = self.cell_anchors - assert cell_anchors is not None - # suppose that all anchors have the same device - # which is a valid assumption in the current state of the codebase - if cell_anchors[0].device == device: - return - - cell_anchors = [ - self.generate_anchors( - sizes, - aspect_ratios, - dtype, - device - ) - for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) - ] - self.cell_anchors = cell_anchors + self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) + for cell_anchor in self.cell_anchors] def num_anchors_per_location(self): return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] @@ -130,7 +115,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) return anchors def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: - grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] image_size = image_list.tensors.shape[-2:] dtype, device = feature_maps[0].dtype, feature_maps[0].device strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), @@ -138,7 +123,7 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) anchors: List[List[torch.Tensor]] = [] - for i in range(len(image_list.image_sizes)): + for _ in range(len(image_list.image_sizes)): anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps] anchors.append(anchors_in_image) anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]