Added annotation typing to squeezenet (pytorch#2865)
* style: Added annotation typing for squeezenet

* feat: Added typing for kwargs
frgfm authored and vfdev-5 committed Dec 4, 2020
1 parent 6c5d163 commit 661769e
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions torchvision/models/squeezenet.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.init as init
 from .utils import load_state_dict_from_url
+from typing import Any

 __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']

@@ -13,8 +14,13 @@

 class Fire(nn.Module):

-    def __init__(self, inplanes, squeeze_planes,
-                 expand1x1_planes, expand3x3_planes):
+    def __init__(
+        self,
+        inplanes: int,
+        squeeze_planes: int,
+        expand1x1_planes: int,
+        expand3x3_planes: int
+    ):
         super(Fire, self).__init__()
         self.inplanes = inplanes
         self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
@@ -26,7 +32,7 @@ def __init__(self, inplanes, squeeze_planes,
                                    kernel_size=3, padding=1)
         self.expand3x3_activation = nn.ReLU(inplace=True)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.squeeze_activation(self.squeeze(x))
         return torch.cat([
             self.expand1x1_activation(self.expand1x1(x)),
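Aside, not part of the commit: a minimal sketch of exercising the freshly annotated Fire block, assuming a torchvision build that includes this patch; the channel counts below are arbitrary example values.

import torch
from torchvision.models.squeezenet import Fire

# Arbitrary example: 96 input channels squeezed to 16, then expanded to 64 + 64.
fire = Fire(inplanes=96, squeeze_planes=16, expand1x1_planes=64, expand3x3_planes=64)
x = torch.randn(1, 96, 54, 54)  # forward() is now annotated torch.Tensor -> torch.Tensor
out = fire(x)
print(out.shape)                # torch.Size([1, 128, 54, 54]): 64 + 64 expand channels concatenated
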
@@ -36,7 +42,11 @@ def forward(self, x):

 class SqueezeNet(nn.Module):

-    def __init__(self, version='1_0', num_classes=1000):
+    def __init__(
+        self,
+        version: str = '1_0',
+        num_classes: int = 1000
+    ):
         super(SqueezeNet, self).__init__()
         self.num_classes = num_classes
         if version == '1_0':
@@ -96,13 +106,13 @@ def __init__(self, version='1_0', num_classes=1000):
                 if m.bias is not None:
                     init.constant_(m.bias, 0)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.features(x)
         x = self.classifier(x)
         return torch.flatten(x, 1)


-def _squeezenet(version, pretrained, progress, **kwargs):
+def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
     model = SqueezeNet(version, **kwargs)
     if pretrained:
         arch = 'squeezenet' + version
@@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs):
     return model


-def squeezenet1_0(pretrained=False, progress=True, **kwargs):
+def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
     r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
     accuracy with 50x fewer parameters and <0.5MB model size"
     <https://arxiv.org/abs/1602.07360>`_ paper.
@@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs):
     return _squeezenet('1_0', pretrained, progress, **kwargs)


-def squeezenet1_1(pretrained=False, progress=True, **kwargs):
+def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
     r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
     <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
     SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
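For reference only (not part of the commit): a small usage sketch of what the typed builders now promise to callers, assuming this patch is installed. The num_classes=10 keyword is an arbitrary example chosen to show a value flowing through the Any-typed **kwargs into the SqueezeNet constructor.

import torch
from torchvision.models import squeezenet1_1

# pretrained/progress are now annotated as bool; the return type is SqueezeNet.
model = squeezenet1_1(pretrained=False, progress=True, num_classes=10)  # num_classes passes through **kwargs: Any
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)   # SqueezeNet.forward: torch.Tensor -> torch.Tensor
print(logits.shape)     # torch.Size([1, 10])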
