From 661769eceef73ba5c3777fe9646efdc3c917f057 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Thu, 22 Oct 2020 17:05:14 +0200 Subject: [PATCH] Added annotation typing to squeezenet (#2865) * style: Added annotation typing for squeezenet * feat: Added typing for kwargs --- torchvision/models/squeezenet.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 964f3ec66da..82448516c03 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -2,6 +2,7 @@ import torch.nn as nn import torch.nn.init as init from .utils import load_state_dict_from_url +from typing import Any __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] @@ -13,8 +14,13 @@ class Fire(nn.Module): - def __init__(self, inplanes, squeeze_planes, - expand1x1_planes, expand3x3_planes): + def __init__( + self, + inplanes: int, + squeeze_planes: int, + expand1x1_planes: int, + expand3x3_planes: int + ): super(Fire, self).__init__() self.inplanes = inplanes self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) @@ -26,7 +32,7 @@ def __init__(self, inplanes, squeeze_planes, kernel_size=3, padding=1) self.expand3x3_activation = nn.ReLU(inplace=True) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.squeeze_activation(self.squeeze(x)) return torch.cat([ self.expand1x1_activation(self.expand1x1(x)), @@ -36,7 +42,11 @@ def forward(self, x): class SqueezeNet(nn.Module): - def __init__(self, version='1_0', num_classes=1000): + def __init__( + self, + version: str = '1_0', + num_classes: int = 1000 + ): super(SqueezeNet, self).__init__() self.num_classes = num_classes if version == '1_0': @@ -96,13 +106,13 @@ def __init__(self, version='1_0', num_classes=1000): if m.bias is not None: init.constant_(m.bias, 0) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.classifier(x) return torch.flatten(x, 1) -def _squeezenet(version, pretrained, progress, **kwargs): +def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: model = SqueezeNet(version, **kwargs) if pretrained: arch = 'squeezenet' + version @@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs): return model -def squeezenet1_0(pretrained=False, progress=True, **kwargs): +def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. @@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs): return _squeezenet('1_0', pretrained, progress, **kwargs) -def squeezenet1_1(pretrained=False, progress=True, **kwargs): +def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters