diff --git a/models/anchor_utils.py b/models/anchor_utils.py index fddb37f2..ebe15107 100644 --- a/models/anchor_utils.py +++ b/models/anchor_utils.py @@ -17,8 +17,13 @@ def __init__( self.strides = strides self.anchor_grids = anchor_grids - def set_wh_weights(self, grid_sizes, dtype, device): - # type: (List[List[int]], int, Device) -> Tensor # noqa: F821 + def set_wh_weights( + self, + grid_sizes: List[List[int]], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + wh_weights = [] for size, stride in zip(grid_sizes, self.strides): @@ -31,8 +36,13 @@ def set_wh_weights(self, grid_sizes, dtype, device): return torch.cat(wh_weights) - def set_xy_weights(self, grid_sizes, dtype, device): - # type: (List[List[int]], int, Device) -> Tensor # noqa: F821 + def set_xy_weights( + self, + grid_sizes: List[List[int]], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + xy_weights = [] for size, anchor_grid in zip(grid_sizes, self.anchor_grids): @@ -45,8 +55,12 @@ def set_xy_weights(self, grid_sizes, dtype, device): return torch.cat(xy_weights) - def grid_anchors(self, grid_sizes, device): - # type: (List[List[int]], Device) -> Tensor # noqa: F821 + def grid_anchors( + self, + grid_sizes: List[List[int]], + device: torch.device = torch.device("cpu"), + ) -> Tensor: + anchors = [] for size in grid_sizes: