Skip to content

Commit

Permalink
Fixed typing in constructors of models submodules (pytorch#2875)
Browse files Browse the repository at this point in the history
* fix: Fixed constructor typing in models._utils

* fix: Fixed constructor typing in models.alexnet

* fix: Fixed constructor typing in models.mnasnet

* fix: Fixed constructor typing in models.squeezenet
  • Loading branch information
frgfm authored and bryant1410 committed Nov 22, 2020
1 parent dc26f0a commit 4dcb5d8
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torchvision/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
"return_layers": Dict[str, str],
}

def __init__(self, model: nn.Module, return_layers: Dict[str, str]):
def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class AlexNet(nn.Module):

def __init__(self, num_classes: int = 1000):
def __init__(self, num_classes: int = 1000) -> None:
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
stride: int,
expansion_factor: int,
bn_momentum: float = 0.1
):
) -> None:
super(_InvertedResidual, self).__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5]
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
alpha: float,
num_classes: int = 1000,
dropout: float = 0.2
):
) -> None:
super(MNASNet, self).__init__()
assert alpha > 0.0
self.alpha = alpha
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
squeeze_planes: int,
expand1x1_planes: int,
expand3x3_planes: int
):
) -> None:
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
Expand All @@ -46,7 +46,7 @@ def __init__(
self,
version: str = '1_0',
num_classes: int = 1000
):
) -> None:
super(SqueezeNet, self).__init__()
self.num_classes = num_classes
if version == '1_0':
Expand Down

0 comments on commit 4dcb5d8

Please sign in to comment.