Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added annotation typing to squeezenet #2865

Merged
merged 2 commits into from
Oct 22, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions torchvision/models/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
import torch.nn.init as init
from .utils import load_state_dict_from_url
from typing import Any

__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']

Expand All @@ -13,8 +14,13 @@

class Fire(nn.Module):

def __init__(self, inplanes, squeeze_planes,
expand1x1_planes, expand3x3_planes):
def __init__(
self,
inplanes: int,
squeeze_planes: int,
expand1x1_planes: int,
expand3x3_planes: int
):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
Expand All @@ -26,7 +32,7 @@ def __init__(self, inplanes, squeeze_planes,
kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.squeeze_activation(self.squeeze(x))
return torch.cat([
self.expand1x1_activation(self.expand1x1(x)),
Expand All @@ -36,7 +42,11 @@ def forward(self, x):

class SqueezeNet(nn.Module):

def __init__(self, version='1_0', num_classes=1000):
def __init__(
self,
version: str = '1_0',
num_classes: int = 1000
):
super(SqueezeNet, self).__init__()
self.num_classes = num_classes
if version == '1_0':
Expand Down Expand Up @@ -96,13 +106,13 @@ def __init__(self, version='1_0', num_classes=1000):
if m.bias is not None:
init.constant_(m.bias, 0)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.classifier(x)
return torch.flatten(x, 1)


def _squeezenet(version, pretrained, progress, **kwargs):
def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
model = SqueezeNet(version, **kwargs)
if pretrained:
arch = 'squeezenet' + version
Expand All @@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs):
return model


def squeezenet1_0(pretrained=False, progress=True, **kwargs):
def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Expand All @@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs):
return _squeezenet('1_0', pretrained, progress, **kwargs)


def squeezenet1_1(pretrained=False, progress=True, **kwargs):
def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
Expand Down