Skip to content

Commit

Permalink
Enhance ShufflenetV2 (#892)
Browse files Browse the repository at this point in the history
* Enhance ShufflenetV2

Class shufflenetv2 receives `stages_repeats` and `stages_out_channels` arguments.

* remove explicit num_classes argument from utility functions
  • Loading branch information
barrh authored and fmassa committed May 8, 2019
1 parent dc3ac29 commit 43ab2fe
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,17 @@ def forward(self, x):


class ShuffleNetV2(nn.Module):
def __init__(self, stage_out_channels, num_classes=1000):
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
super(ShuffleNetV2, self).__init__()

self.stage_out_channels = stage_out_channels
input_channels = 3
output_channels = self.stage_out_channels[0]
if len(stages_repeats) != 3:
raise ValueError('expected stages_repeats as list of 3 positive ints')
if len(stages_out_channels) != 5:
raise ValueError('expected stages_out_channels as list of 5 positive ints')
self._stage_out_channels = stages_out_channels

input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
Expand All @@ -101,16 +105,15 @@ def __init__(self, stage_out_channels, num_classes=1000):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
stage_repeats = [4, 8, 4]
for name, repeats, output_channels in zip(
stage_names, stage_repeats, self.stage_out_channels[1:]):
stage_names, stages_repeats, self._stage_out_channels[1:]):
seq = [InvertedResidual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(InvertedResidual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels

output_channels = self.stage_out_channels[-1]
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
Expand All @@ -131,8 +134,8 @@ def forward(self, x):
return x


def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
model = ShuffleNetV2(*args, **kwargs)

if pretrained:
model_url = model_urls[arch]
Expand All @@ -146,16 +149,20 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):


def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs)
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs)
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs)
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)

0 comments on commit 43ab2fe

Please sign in to comment.