From f00c39203fff861b8e1385bdd4a00d5390faa9c4 Mon Sep 17 00:00:00 2001 From: barrh Date: Fri, 12 Apr 2019 05:26:51 +0300 Subject: [PATCH 1/3] Add ShuffleNet v2 Added 4 configurations: x0.5, x1, x1.5, x2 Add 2 pretrained models: x0.5, x1 --- docs/source/models.rst | 10 ++ torchvision/models/__init__.py | 1 + torchvision/models/shufflenetv2.py | 184 +++++++++++++++++++++++++++++ 3 files changed, 195 insertions(+) create mode 100644 torchvision/models/shufflenetv2.py diff --git a/docs/source/models.rst b/docs/source/models.rst index 308ba75481b..66bb60e2004 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -11,6 +11,7 @@ architectures: - `DenseNet`_ - `Inception`_ v3 - `GoogLeNet`_ +- `ShuffleNet`_ v2 You can construct a model with random weights by calling its constructor: @@ -24,6 +25,7 @@ You can construct a model with random weights by calling its constructor: densenet = models.densenet161() inception = models.inception_v3() googlenet = models.googlenet() + shufflenet = models.shufflenetv2() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -38,6 +40,7 @@ These can be constructed by passing ``pretrained=True``: densenet = models.densenet161(pretrained=True) inception = models.inception_v3(pretrained=True) googlenet = models.googlenet(pretrained=True) + shufflenet = models.shufflenetv2(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See @@ -88,6 +91,7 @@ Densenet-201 22.80 6.43 Densenet-161 22.35 6.20 Inception v3 22.55 6.44 GoogleNet 30.22 10.47 +ShuffleNet V2 30.64 11.68 ================================ ============= ============= @@ -98,6 +102,7 @@ GoogleNet 30.22 10.47 .. _DenseNet: https://arxiv.org/abs/1608.06993 .. _Inception: https://arxiv.org/abs/1512.00567 .. _GoogLeNet: https://arxiv.org/abs/1409.4842 +.. _ShuffleNet: https://arxiv.org/abs/1807.11164 .. currentmodule:: torchvision.models @@ -152,3 +157,8 @@ GoogLeNet .. autofunction:: googlenet +ShuffleNet v2 +------------- + +.. autofunction:: shufflenet + diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 4b5e8e657e0..727aed44dfb 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -6,3 +6,4 @@ from .densenet import * from .googlenet import * from .mobilenet import * +from .shufflenetv2 import * diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py new file mode 100644 index 00000000000..af3cab01ecd --- /dev/null +++ b/torchvision/models/shufflenetv2.py @@ -0,0 +1,184 @@ +import functools + +import torch +import torch.nn as nn + +__all__ = ['ShuffleNetV2', 'shufflenetv2', + 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', + 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] + +model_urls = { + 'shufflenetv2_x0.5': 'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt', + 'shufflenetv2_x1.0': 'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt', + 'shufflenetv2_x1.5': None, + 'shufflenetv2_x2.0': None, +} + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + # reshape + x = x.view(batchsize, groups, + channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + x = x.view(batchsize, -1, height, width) + + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride): + super(InvertedResidual, self).__init__() + + if not (1 <= stride <= 3): + raise ValueError('illegal stride value') + self.stride = stride + + branch_features = oup // 2 + assert (self.stride != 1) or (inp == branch_features<<1) + + pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False) + dw_conv33 = functools.partial(self.depthwise_conv, + kernel_size=3, stride=self.stride, padding=1) + + if self.stride > 1: + self.branch1 = nn.Sequential( + dw_conv33(inp, inp), + nn.BatchNorm2d(inp), + pw_conv11(inp, branch_features), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + ) + + self.branch2 = nn.Sequential( + pw_conv11(inp if (self.stride > 1) else branch_features, branch_features), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + dw_conv33(branch_features, branch_features), + nn.BatchNorm2d(branch_features), + pw_conv11(branch_features, branch_features), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + +class ShuffleNetV2(nn.Module): + def __init__(self, num_classes=1000, input_size=224, width_mult=1): + super(ShuffleNetV2, self).__init__() + + try: + self.stage_out_channels = self._getStages(float(width_mult)) + except KeyError: + raise ValueError('width_mult {} is not supported'.format(width_mult)) + + input_channels = 3 + output_channels = self.stage_out_channels[0] + self.conv1 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + input_channels = output_channels + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] + stage_repeats = [4, 8, 4] + for name, repeats, output_channels in zip( + stage_names, stage_repeats, self.stage_out_channels[1:]): + seq = [InvertedResidual(input_channels, output_channels, 2)] + for i in range(repeats-1): + seq.append(InvertedResidual(output_channels, output_channels, 1)) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + + output_channels = self.stage_out_channels[-1] + self.conv5 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + + if (input_size % 32): + raise ValueError('illegal input_size') + self.globalpool = nn.AvgPool2d(int(input_size/32)) + + # expected ifm size is: channels x 1 x 1 + self.fc = nn.Linear(self.stage_out_channels[-1], num_classes) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + x = self.conv5(x) + x = self.globalpool(x) + x = x.view(-1, self.stage_out_channels[-1]) + x = self.fc(x) + return x + + @staticmethod + def _getStages(mult): + stages = { + '0.5': [24, 48, 96, 192, 1024], + '1.0': [24, 116, 232, 464, 1024], + '1.5': [24, 176, 352, 704, 1024], + '2.0': [24, 244, 488, 976, 2048], + } + return stages[str(mult)] + + +def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult=1, **kwargs): + model = ShuffleNetV2(num_classes=num_classes, input_size=input_size, width_mult=width_mult) + + if pretrained: + # change width_mult to float + if isinstance(width_mult, int): + width_mult = float(width_mult) + model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)])) + try: + model_url = model_urls[model_type.lower()] + except KeyError: + raise ValueError('model {} is not support'.format(model_type)) + if model_url is None: + raise NotImplementedError('pretrained {} is not supported'.format(model_type)) + model.load_state_dict(torch.utils.model_zoo.load_url(model_url)) + + return model + + +def shufflenetv2_x0_5(pretrained=False, num_classes=1000, input_size=224, **kwargs): + return shufflenetv2(pretrained, num_classes, input_size, 0.5) + + +def shufflenetv2_x1_0(pretrained=False, num_classes=1000, input_size=224, **kwargs): + return shufflenetv2(pretrained, num_classes, input_size, 1) + + +def shufflenetv2_x1_5(pretrained=False, num_classes=1000, input_size=224, **kwargs): + return shufflenetv2(pretrained, num_classes, input_size, 1.5) + + +def shufflenetv2_x2_0(pretrained=False, num_classes=1000, input_size=224, **kwargs): + return shufflenetv2(pretrained, num_classes, input_size, 2) From de95f2ba44903b7c3fc4785a86a330c2801ab6cd Mon Sep 17 00:00:00 2001 From: barrh Date: Fri, 12 Apr 2019 12:48:43 +0300 Subject: [PATCH 2/3] fix lint --- torchvision/models/shufflenetv2.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index af3cab01ecd..8376eec9398 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -8,8 +8,10 @@ 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] model_urls = { - 'shufflenetv2_x0.5': 'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt', - 'shufflenetv2_x1.0': 'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt', + 'shufflenetv2_x0.5': + 'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt', + 'shufflenetv2_x1.0': + 'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt', 'shufflenetv2_x1.5': None, 'shufflenetv2_x2.0': None, } @@ -40,7 +42,7 @@ def __init__(self, inp, oup, stride): self.stride = stride branch_features = oup // 2 - assert (self.stride != 1) or (inp == branch_features<<1) + assert (self.stride != 1) or (inp == branch_features << 1) pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False) dw_conv33 = functools.partial(self.depthwise_conv, @@ -105,9 +107,9 @@ def __init__(self, num_classes=1000, input_size=224, width_mult=1): stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] stage_repeats = [4, 8, 4] for name, repeats, output_channels in zip( - stage_names, stage_repeats, self.stage_out_channels[1:]): + stage_names, stage_repeats, self.stage_out_channels[1:]): seq = [InvertedResidual(input_channels, output_channels, 2)] - for i in range(repeats-1): + for i in range(repeats - 1): seq.append(InvertedResidual(output_channels, output_channels, 1)) setattr(self, name, nn.Sequential(*seq)) input_channels = output_channels @@ -121,7 +123,7 @@ def __init__(self, num_classes=1000, input_size=224, width_mult=1): if (input_size % 32): raise ValueError('illegal input_size') - self.globalpool = nn.AvgPool2d(int(input_size/32)) + self.globalpool = nn.AvgPool2d(int(input_size / 32)) # expected ifm size is: channels x 1 x 1 self.fc = nn.Linear(self.stage_out_channels[-1], num_classes) @@ -141,7 +143,7 @@ def forward(self, x): @staticmethod def _getStages(mult): stages = { - '0.5': [24, 48, 96, 192, 1024], + '0.5': [24, 48, 96, 192, 1024], '1.0': [24, 116, 232, 464, 1024], '1.5': [24, 176, 352, 704, 1024], '2.0': [24, 244, 488, 976, 2048], From 4c8abb69b471d16a13dd3e29c22d2c6a530c9600 Mon Sep 17 00:00:00 2001 From: barrh Date: Mon, 15 Apr 2019 18:37:43 +0300 Subject: [PATCH 3/3] Change globalpool to torch.mean() call --- torchvision/models/shufflenetv2.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 8376eec9398..e53633d21d3 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -121,12 +121,7 @@ def __init__(self, num_classes=1000, input_size=224, width_mult=1): nn.ReLU(inplace=True), ) - if (input_size % 32): - raise ValueError('illegal input_size') - self.globalpool = nn.AvgPool2d(int(input_size / 32)) - - # expected ifm size is: channels x 1 x 1 - self.fc = nn.Linear(self.stage_out_channels[-1], num_classes) + self.fc = nn.Linear(output_channels, num_classes) def forward(self, x): x = self.conv1(x) @@ -135,8 +130,7 @@ def forward(self, x): x = self.stage3(x) x = self.stage4(x) x = self.conv5(x) - x = self.globalpool(x) - x = x.view(-1, self.stage_out_channels[-1]) + x = x.mean([2, 3]) # globalpool x = self.fc(x) return x