diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 5726cea2a22..c56bab30bbd 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -84,13 +84,17 @@ def forward(self, x): class ShuffleNetV2(nn.Module): - def __init__(self, stage_out_channels, num_classes=1000): + def __init__(self, stages_repeats, stages_out_channels, num_classes=1000): super(ShuffleNetV2, self).__init__() - self.stage_out_channels = stage_out_channels - input_channels = 3 - output_channels = self.stage_out_channels[0] + if len(stages_repeats) != 3: + raise ValueError('expected stages_repeats as list of 3 positive ints') + if len(stages_out_channels) != 5: + raise ValueError('expected stages_out_channels as list of 5 positive ints') + self._stage_out_channels = stages_out_channels + input_channels = 3 + output_channels = self._stage_out_channels[0] self.conv1 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channels), @@ -101,16 +105,15 @@ def __init__(self, stage_out_channels, num_classes=1000): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] - stage_repeats = [4, 8, 4] for name, repeats, output_channels in zip( - stage_names, stage_repeats, self.stage_out_channels[1:]): + stage_names, stages_repeats, self._stage_out_channels[1:]): seq = [InvertedResidual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append(InvertedResidual(output_channels, output_channels, 1)) setattr(self, name, nn.Sequential(*seq)) input_channels = output_channels - output_channels = self.stage_out_channels[-1] + output_channels = self._stage_out_channels[-1] self.conv5 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(output_channels), @@ -131,8 +134,8 @@ def forward(self, x): return x -def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): - model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs) +def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): + model = ShuffleNetV2(*args, **kwargs) if pretrained: model_url = model_urls[arch] @@ -146,16 +149,20 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs) + return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, + [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs) + return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, + [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs) + return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, + [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs) + return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, + [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)