Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ShuffleNet v2 #849

Merged
merged 3 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ architectures:
- `DenseNet`_
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2

You can construct a model with random weights by calling its constructor:

Expand All @@ -24,6 +25,7 @@ You can construct a model with random weights by calling its constructor:
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenetv2()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
Expand All @@ -38,6 +40,7 @@ These can be constructed by passing ``pretrained=True``:
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenetv2(pretrained=True)

Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
Expand Down Expand Up @@ -88,6 +91,7 @@ Densenet-201 22.80 6.43
Densenet-161 22.35 6.20
Inception v3 22.55 6.44
GoogleNet 30.22 10.47
ShuffleNet V2 30.64 11.68
================================ ============= =============


Expand All @@ -98,6 +102,7 @@ GoogleNet 30.22 10.47
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164

.. currentmodule:: torchvision.models

Expand Down Expand Up @@ -152,3 +157,8 @@ GoogLeNet

.. autofunction:: googlenet

ShuffleNet v2
-------------

.. autofunction:: shufflenetv2

1 change: 1 addition & 0 deletions torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .shufflenetv2 import *
180 changes: 180 additions & 0 deletions torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import functools

import torch
import torch.nn as nn

# Public API of this module.
__all__ = ['ShuffleNetV2', 'shufflenetv2',
           'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
           'shufflenetv2_x1_5', 'shufflenetv2_x2_0']

# Pretrained ImageNet checkpoints, keyed by 'shufflenetv2_x<width_mult>'.
# A value of None means no pretrained weights are published for that width;
# requesting them raises NotImplementedError in shufflenetv2() below.
model_urls = {
    'shufflenetv2_x0.5':
        'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt',
    'shufflenetv2_x1.0':
        'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}


def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups``.

    Used after concatenating the two branches of an ``InvertedResidual``
    unit so information is exchanged between branches (ShuffleNet v2).

    Args:
        x: 4-D tensor of shape ``(N, C, H, W)``; ``C`` must be divisible
            by ``groups``.
        groups: number of channel groups to shuffle between.

    Returns:
        Tensor of the same shape with channels re-ordered so that
        consecutive output channels come from different groups.
    """
    # Use x.size() rather than the deprecated x.data.size(); accessing
    # .data bypasses autograd tracking and is discouraged.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape to (N, groups, C/groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # Swap the group and per-group-channel axes; the subsequent flatten
    # then interleaves channels from different groups.
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten back to (N, C, H, W)
    return x.view(batchsize, -1, height, width)


class InvertedResidual(nn.Module):
    """ShuffleNet v2 basic unit.

    With ``stride == 1`` the input is split channel-wise: one half passes
    through unchanged while the other goes through a pointwise-depthwise-
    pointwise stack (requires ``inp == oup``). With ``stride > 1`` both
    branches process the full input and the spatial size is reduced.
    The two halves are concatenated and channel-shuffled on the way out.
    """

    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        # Each branch produces half of the output channels.
        branch_features = oup // 2
        # In the stride-1 case the untouched half must already match the
        # processed half, so the input width must equal 2 * branch width.
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            # Down-sampling branch: depthwise 3x3 (stride s) -> pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3,
                                    stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1,
                          stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        # Main branch: pw 1x1 -> dw 3x3 (stride s) -> pw 1x1. When the
        # input is split (stride 1), it only sees half of the channels.
        branch2_in = inp if self.stride > 1 else branch_features
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features,
                                kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        # groups == input channels makes the convolution depthwise.
        return nn.Conv2d(i, o, kernel_size, stride, padding,
                         bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            keep, work = x.chunk(2, dim=1)
            out = torch.cat((keep, self.branch2(work)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Mix the two halves so information flows between branches.
        return channel_shuffle(out, 2)


class ShuffleNetV2(nn.Module):
    """ShuffleNet v2 architecture (https://arxiv.org/abs/1807.11164).

    Args:
        num_classes: size of the final classifier output.
        input_size: nominal input resolution; accepted for API
            compatibility but not used by the network itself (the model
            is fully convolutional up to the global pool).
        width_mult: channel-width multiplier; one of 0.5, 1, 1.5, 2.

    Raises:
        ValueError: if ``width_mult`` is not one of the supported values.
    """

    def __init__(self, num_classes=1000, input_size=224, width_mult=1):
        super(ShuffleNetV2, self).__init__()

        try:
            self.stage_out_channels = self._getStages(float(width_mult))
        except KeyError:
            raise ValueError('width_mult {} is not supported'.format(width_mult))

        # Stem: 3x3 stride-2 conv followed by a stride-2 max pool.
        in_channels = 3
        out_channels = self.stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        in_channels = out_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages 2-4: each opens with one stride-2 (down-sampling) unit,
        # followed by (repeats - 1) stride-1 units.
        for stage_idx, (repeats, out_channels) in enumerate(
                zip([4, 8, 4], self.stage_out_channels[1:]), start=2):
            units = [InvertedResidual(in_channels, out_channels, 2)]
            units.extend(InvertedResidual(out_channels, out_channels, 1)
                         for _ in range(repeats - 1))
            setattr(self, 'stage{}'.format(stage_idx), nn.Sequential(*units))
            in_channels = out_channels

        # Final 1x1 conv expanding to the last width-table entry.
        out_channels = self.stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(out_channels, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        # Global average pool over the spatial dimensions.
        x = x.mean([2, 3])
        return self.fc(x)

    @staticmethod
    def _getStages(mult):
        # Channel widths per width multiplier:
        # [stem, stage2, stage3, stage4, final 1x1 conv].
        width_table = {
            '0.5': [24, 48, 96, 192, 1024],
            '1.0': [24, 116, 232, 464, 1024],
            '1.5': [24, 176, 352, 704, 1024],
            '2.0': [24, 244, 488, 976, 2048],
        }
        return width_table[str(mult)]


def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult=1, **kwargs):
    """Construct a ShuffleNet v2 model.

    Args:
        pretrained: if True, load ImageNet-pretrained weights (only
            published for width_mult 0.5 and 1.0; see ``model_urls``).
        num_classes: size of the classifier output.
        input_size: nominal input resolution (kept for API compatibility;
            unused by the network).
        width_mult: channel-width multiplier; one of 0.5, 1, 1.5, 2.
        **kwargs: accepted for interface compatibility; ignored.

    Returns:
        A ``ShuffleNetV2`` instance.

    Raises:
        ValueError: unknown ``width_mult`` / model type.
        NotImplementedError: pretrained weights requested for a width
            that has none published.
    """
    model = ShuffleNetV2(num_classes=num_classes, input_size=input_size,
                         width_mult=width_mult)

    if pretrained:
        # Normalize the multiplier (e.g. 1 -> 1.0) so that the string
        # form matches the keys used in model_urls.
        width_mult = float(width_mult)
        model_type = '_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)])
        try:
            model_url = model_urls[model_type.lower()]
        except KeyError:
            raise ValueError('model {} is not supported'.format(model_type))
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported'.format(model_type))
        # Imported here explicitly: the module only does `import torch`,
        # which does not guarantee torch.utils.model_zoo is loaded.
        from torch.utils import model_zoo
        model.load_state_dict(model_zoo.load_url(model_url))

    return model


def shufflenetv2_x0_5(pretrained=False, num_classes=1000, input_size=224, **kwargs):
    """ShuffleNet v2 with width multiplier 0.5 (pretrained weights available)."""
    # Forward **kwargs instead of silently dropping them.
    return shufflenetv2(pretrained, num_classes, input_size, 0.5, **kwargs)


def shufflenetv2_x1_0(pretrained=False, num_classes=1000, input_size=224, **kwargs):
    """ShuffleNet v2 with width multiplier 1.0 (pretrained weights available)."""
    # Forward **kwargs instead of silently dropping them.
    return shufflenetv2(pretrained, num_classes, input_size, 1, **kwargs)


def shufflenetv2_x1_5(pretrained=False, num_classes=1000, input_size=224, **kwargs):
    """ShuffleNet v2 with width multiplier 1.5 (no pretrained weights published)."""
    # Forward **kwargs instead of silently dropping them.
    return shufflenetv2(pretrained, num_classes, input_size, 1.5, **kwargs)


def shufflenetv2_x2_0(pretrained=False, num_classes=1000, input_size=224, **kwargs):
    """ShuffleNet v2 with width multiplier 2.0 (no pretrained weights published)."""
    # Forward **kwargs instead of silently dropping them.
    return shufflenetv2(pretrained, num_classes, input_size, 2, **kwargs)