From b9b69ba3e7c746f0ca7312a70532eae2282e03ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Sat, 8 Dec 2018 14:16:45 +0100 Subject: [PATCH 01/11] Add GoogLeNet (Inception v1) --- docs/source/models.rst | 9 ++ torchvision/models/__init__.py | 1 + torchvision/models/googlenet.py | 166 ++++++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+) create mode 100644 torchvision/models/googlenet.py diff --git a/docs/source/models.rst b/docs/source/models.rst index 674ac052c8d..e8876d445cd 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -10,6 +10,7 @@ architectures: - `SqueezeNet`_ - `DenseNet`_ - `Inception`_ v3 +- `GoogLeNet`_ You can construct a model with random weights by calling its constructor: @@ -22,6 +23,7 @@ You can construct a model with random weights by calling its constructor: squeezenet = models.squeezenet1_0() densenet = models.densenet161() inception = models.inception_v3() + googlenet = models.googlenet() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -35,6 +37,7 @@ These can be constructed by passing ``pretrained=True``: vgg16 = models.vgg16(pretrained=True) densenet = models.densenet161(pretrained=True) inception = models.inception_v3(pretrained=True) + googlenet = models.googlenet(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See @@ -93,6 +96,7 @@ Inception v3 22.55 6.44 .. _SqueezeNet: https://arxiv.org/abs/1602.07360 .. _DenseNet: https://arxiv.org/abs/1608.06993 .. _Inception: https://arxiv.org/abs/1512.00567 +.. _GoogLeNet: https://arxiv.org/abs/1409.4842 .. currentmodule:: torchvision.models @@ -142,3 +146,8 @@ Inception v3 .. autofunction:: inception_v3 +GoogLeNet +------------ + +.. autofunction:: googlenet + diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 079992e0269..7437c51597f 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -4,3 +4,4 @@ from .squeezenet import * from .inception import * from .densenet import * +from .googlenet import * diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py new file mode 100644 index 00000000000..d32816f759d --- /dev/null +++ b/torchvision/models/googlenet.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import model_zoo + +__all__ = ['GoogLeNet', 'googlenet'] + +model_urls = { + 'googlenet': '' +} + + +def googlenet(pretrained=False, **kwargs): + r"""GoogLeNet (Inception v1) model architecture from + `"Going Deeper with Convolutions" `_. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + model = GoogLeNet(**kwargs) + model.load_state_dict(model_zoo.load_url(model_urls['googlenet'])) + return model + + return GoogLeNet(**kwargs) + + +class GoogLeNet(nn.Module): + + def __init__(self, num_classes=1000, aux_logits=True): + super(GoogLeNet, self).__init__() + self.aux_logits = aux_logits + + self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1) + self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001) + self.conv2 = BasicConv2d(64, 64, kernel_size=1) + self.conv3 = BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1) + self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001) + self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1) + + self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.maxpool3 = nn.MaxPool2d(3, stride=2) + + self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) + self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1) + + self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) + if aux_logits: + self.aux1 = InceptionAux(512, num_classes) + self.aux2 = InceptionAux(528, num_classes) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.dropout = nn.Dropout(0.4) + self.fc = nn.Linear(1024, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + import scipy.stats as stats + X = stats.truncnorm(-2, 2, scale=0.01) + values = torch.Tensor(X.rvs(m.weight.numel())) + values = values.view(m.weight.size()) + m.weight.data.copy_(values) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool1(x) + x = self.lrn1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.lrn2(x) + x = self.maxpool2(x) + + x = self.inception3a(x) + x = self.inception3b(x) + x = self.maxpool3(x) + x = self.inception4a(x) + if self.training and self.aux_logits: + aux1 = self.aux1(x) + + x = self.inception4b(x) + x = self.inception4c(x) + x = self.inception4d(x) + if self.training and self.aux_logits: + aux2 = self.aux2(x) + + x = self.inception4e(x) + x = self.maxpool4(x) + x = self.inception5a(x) + x = self.inception5b(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.dropout(x) + x = self.fc(x) + if self.training and self.aux_logits: + return aux1, aux2, x + return x + + +class Inception(nn.Module): + + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): + super(Inception, self).__init__() + + self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) + + self.branch2 = nn.Sequential( + BasicConv2d(in_channels, ch3x3red, kernel_size=1, stride=1), + BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) + ) + + self.branch3 = nn.Sequential( + BasicConv2d(in_channels, ch5x5red, kernel_size=1), + BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2) + ) + + self.branch4 = nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + BasicConv2d(in_channels, pool_proj, kernel_size=1) + ) + + def forward(self, x): + branch1 = self.branch1(x) + branch2 = self.branch2(x) + branch3 = self.branch3(x) + branch4 = self.branch4(x) + + 
outputs = [branch1, branch2, branch3, branch4] + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes): + super(InceptionAux, self).__init__() + self.conv = BasicConv2d(in_channels, 128, kernel_size=1) + + self.fc1 = nn.Linear(128 * 3 * 3, 1024) + self.fc2 = nn.Linear(1024, num_classes) + + def forward(self, x): + x = F.avg_pool2d(x, kernel_size=5, stride=3) + + x = self.conv(x) + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = F.dropout(x, 0.7, training=self.training) + x = self.fc2(x) + + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) + + def forward(self, x): + x = self.conv(x) + return F.relu(x, inplace=True) From c0ef07926483dd681cd2517b395b8f5fcc1ca65d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Sun, 30 Dec 2018 21:18:13 +0100 Subject: [PATCH 02/11] Fix missing padding --- torchvision/models/googlenet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index d32816f759d..faed216ca5e 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -40,7 +40,7 @@ def __init__(self, num_classes=1000, aux_logits=True): self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) - self.maxpool3 = nn.MaxPool2d(3, stride=2) + self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1) self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) @@ -140,7 +140,7 @@ def __init__(self, in_channels, num_classes): super(InceptionAux, self).__init__() self.conv = BasicConv2d(in_channels, 128, kernel_size=1) - self.fc1 = nn.Linear(128 * 3 * 3, 1024) + self.fc1 = nn.Linear(128 * 4 * 4, 1024) self.fc2 = nn.Linear(1024, num_classes) def forward(self, x): From f0ae7d15b5f56f6e48ac1e4eabfe39ed5f73b584 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Sun, 30 Dec 2018 22:44:19 +0100 Subject: [PATCH 03/11] Add missing ReLu to aux classifier --- torchvision/models/googlenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index faed216ca5e..379ed0e0417 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -148,7 +148,7 @@ def forward(self, x): x = self.conv(x) x = x.view(x.size(0), -1) - x = self.fc1(x) + x = F.relu(self.fc1(x), inplace=True) x = F.dropout(x, 0.7, training=self.training) x = self.fc2(x) From 6c0d34d250bbc85a556f58f6aed8190ed14e79eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Thu, 3 Jan 2019 18:09:18 +0100 Subject: [PATCH 04/11] Add Batch normalized version of GoogLeNet --- docs/source/models.rst | 1 + torchvision/models/googlenet.py | 79 +++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index e8876d445cd..50c0bb93d77 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -150,4 +150,5 @@ GoogLeNet ------------ .. autofunction:: googlenet +.. 
autofunction:: googlenet_bn diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 379ed0e0417..d9297b328bb 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -3,10 +3,11 @@ import torch.nn.functional as F from torch.utils import model_zoo -__all__ = ['GoogLeNet', 'googlenet'] +__all__ = ['GoogLeNet', 'googlenet', 'googlenet_bn'] model_urls = { - 'googlenet': '' + 'googlenet': '', + 'googlenet_bn': '' } @@ -24,36 +25,50 @@ def googlenet(pretrained=False, **kwargs): return GoogLeNet(**kwargs) +def googlenet_bn(pretrained=False, **kwargs): + r"""GoogLeNet (Inception v1) model architecture with batch normalization from + `"Going Deeper with Convolutions" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + model = GoogLeNet(batch_norm=True, **kwargs) + model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn'])) + return model + + return GoogLeNet(batch_norm=True, **kwargs) + + class GoogLeNet(nn.Module): - def __init__(self, num_classes=1000, aux_logits=True): + def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False): super(GoogLeNet, self).__init__() self.aux_logits = aux_logits - self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3) self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1) self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001) - self.conv2 = BasicConv2d(64, 64, kernel_size=1) - self.conv3 = BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1) + self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1) + self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, stride=1, padding=1) self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001) self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1) - self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) - self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm) + self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, batch_norm) self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1) - self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) - self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) - self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) - self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) - self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, batch_norm) + self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, batch_norm) + self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, batch_norm) + self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, batch_norm) + self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, batch_norm) self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1) - self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) - self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) + self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, batch_norm) + self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, batch_norm) if aux_logits: - self.aux1 = InceptionAux(512, num_classes) - self.aux2 = InceptionAux(528, num_classes) + self.aux1 = InceptionAux(512, num_classes, batch_norm) + self.aux2 = InceptionAux(528, num_classes, batch_norm) self.avgpool = nn.AvgPool2d(7, stride=1) self.dropout = nn.Dropout(0.4) self.fc = nn.Linear(1024, num_classes) @@ 
-104,24 +119,24 @@ def forward(self, x): class Inception(nn.Module): - def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, batch_norm=False): super(Inception, self).__init__() - self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) + self.branch1 = BasicConv2d(in_channels, ch1x1, batch_norm, kernel_size=1) self.branch2 = nn.Sequential( - BasicConv2d(in_channels, ch3x3red, kernel_size=1, stride=1), - BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) + BasicConv2d(in_channels, ch3x3red, batch_norm, kernel_size=1, stride=1), + BasicConv2d(ch3x3red, ch3x3, batch_norm, kernel_size=3, padding=1) ) self.branch3 = nn.Sequential( - BasicConv2d(in_channels, ch5x5red, kernel_size=1), - BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2) + BasicConv2d(in_channels, ch5x5red, batch_norm, kernel_size=1), + BasicConv2d(ch5x5red, ch5x5, batch_norm, kernel_size=5, padding=2) ) self.branch4 = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - BasicConv2d(in_channels, pool_proj, kernel_size=1) + BasicConv2d(in_channels, pool_proj, batch_norm, kernel_size=1) ) def forward(self, x): @@ -136,9 +151,9 @@ def forward(self, x): class InceptionAux(nn.Module): - def __init__(self, in_channels, num_classes): + def __init__(self, in_channels, num_classes, batch_norm=False): super(InceptionAux, self).__init__() - self.conv = BasicConv2d(in_channels, 128, kernel_size=1) + self.conv = BasicConv2d(in_channels, 128, batch_norm, kernel_size=1) self.fc1 = nn.Linear(128 * 4 * 4, 1024) self.fc2 = nn.Linear(1024, num_classes) @@ -157,10 +172,18 @@ def forward(self, x): class BasicConv2d(nn.Module): - def __init__(self, in_channels, out_channels, **kwargs): + def __init__(self, in_channels, out_channels, batch_norm=False, **kwargs): super(BasicConv2d, self).__init__() - self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) + self.batch_norm = batch_norm + + if self.batch_norm: + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + else: + self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) def forward(self, x): x = self.conv(x) + if self.batch_norm: + x = self.bn(x) return F.relu(x, inplace=True) From 2c8caabda887d0c97f970a157ad88eaf24c20415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Wed, 9 Jan 2019 15:47:52 +0100 Subject: [PATCH 05/11] Use ceil_mode instead of padding and initialize weights using "xavier" --- torchvision/models/googlenet.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index d9297b328bb..9a7c6fa4c0a 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -46,23 +46,23 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False): self.aux_logits = aux_logits self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3) - self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1) + self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001) self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1) self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, stride=1, padding=1) self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001) - self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1) + self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm) self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, batch_norm) - self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1) + self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, batch_norm) self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, batch_norm) self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, batch_norm) self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, batch_norm) self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, batch_norm) - self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1) + self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, batch_norm) self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, batch_norm) @@ -75,11 +75,9 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False): for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - import scipy.stats as stats - X = stats.truncnorm(-2, 2, scale=0.01) - values = torch.Tensor(X.rvs(m.weight.numel())) - values = values.view(m.weight.size()) - m.weight.data.copy_(values) + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.2) def forward(self, x): x = self.conv1(x) @@ -135,7 +133,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr ) self.branch4 = nn.Sequential( - nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), BasicConv2d(in_channels, pool_proj, batch_norm, kernel_size=1) ) From 2c235ff2d727cce9a832bb9ebfb50ad302c4a92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Wed, 9 Jan 2019 19:23:28 +0100 Subject: [PATCH 06/11] Match BVLC GoogLeNet zero initialization of classifier --- torchvision/models/googlenet.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 9a7c6fa4c0a..4a91f05a0a5 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -18,6 +18,7 @@ def googlenet(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: + kwargs['init_weights'] = False model = GoogLeNet(**kwargs) model.load_state_dict(model_zoo.load_url(model_urls['googlenet'])) return model @@ -32,6 +33,7 @@ def googlenet_bn(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: + kwargs['init_weights'] = False model = GoogLeNet(batch_norm=True, **kwargs) model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn'])) return model @@ -41,7 +43,7 @@ def googlenet_bn(pretrained=False, **kwargs): class GoogLeNet(nn.Module): - def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False): + def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_weights=True): super(GoogLeNet, self).__init__() self.aux_logits = aux_logits @@ -73,11 +75,21 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False): self.dropout = nn.Dropout(0.4) self.fc = nn.Linear(1024, num_classes) + if init_weights: + self._initialize_weights() + + def _initialize_weights(self): for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + if isinstance(m, nn.Conv2d): nn.init.xavier_uniform_(m.weight) if 
m.bias is not None: nn.init.constant_(m.bias, 0.2) + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) def forward(self, x): x = self.conv1(x) From 384df08e0e9886d0b75bb420a6bbc6b2c08c21ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Sun, 3 Feb 2019 11:59:41 +0100 Subject: [PATCH 07/11] Small cleanup --- torchvision/models/googlenet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 4a91f05a0a5..2b99811fb27 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -49,10 +49,10 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_wei self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3) self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) - self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001) + self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1) - self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, stride=1, padding=1) - self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001) + self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, padding=1) + self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm) @@ -135,7 +135,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr self.branch1 = BasicConv2d(in_channels, ch1x1, batch_norm, kernel_size=1) self.branch2 = nn.Sequential( - BasicConv2d(in_channels, ch3x3red, batch_norm, kernel_size=1, stride=1), + BasicConv2d(in_channels, ch3x3red, batch_norm, kernel_size=1), BasicConv2d(ch3x3red, ch3x3, batch_norm, kernel_size=3, padding=1) ) From f6409964d4d0dffa20f4e0c6a8b3d141f116a600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Tue, 26 Feb 2019 19:10:16 +0100 Subject: [PATCH 08/11] use adaptive avg pool --- torchvision/models/googlenet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 2b99811fb27..f3816190351 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -71,7 +71,7 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_wei if aux_logits: self.aux1 = InceptionAux(512, num_classes, batch_norm) self.aux2 = InceptionAux(528, num_classes, batch_norm) - self.avgpool = nn.AvgPool2d(7, stride=1) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.dropout = nn.Dropout(0.4) self.fc = nn.Linear(1024, num_classes) @@ -169,7 +169,7 @@ def __init__(self, in_channels, num_classes, batch_norm=False): self.fc2 = nn.Linear(1024, num_classes) def forward(self, x): - x = F.avg_pool2d(x, kernel_size=5, stride=3) + x = F.adaptive_avg_pool2d(x, (4, 4)) x = self.conv(x) x = x.view(x.size(0), -1) From d30200509ee1aa4a20d7e4197c7f14ae72ed5eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Thu, 7 Mar 2019 00:22:44 +0100 Subject: [PATCH 09/11] adjust network to match TensorFlow --- docs/source/models.rst | 1 - torchvision/models/googlenet.py | 98 ++++++++++++++------------------- 2 files changed, 41 insertions(+), 58 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst 
index 50c0bb93d77..e8876d445cd 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -150,5 +150,4 @@ GoogLeNet ------------ .. autofunction:: googlenet -.. autofunction:: googlenet_bn diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index f3816190351..31ae3418e2a 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -3,11 +3,11 @@ import torch.nn.functional as F from torch.utils import model_zoo -__all__ = ['GoogLeNet', 'googlenet', 'googlenet_bn'] +__all__ = ['GoogLeNet', 'googlenet'] model_urls = { - 'googlenet': '', - 'googlenet_bn': '' + # GoogLeNet ported from TensorFlow + 'googlenet': 'https://github.com/TheCodez/vision/releases/download/1.0/googlenet-1378be20.pth', } @@ -18,6 +18,8 @@ def googlenet(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: + if 'transform_input' not in kwargs: + kwargs['transform_input'] = True kwargs['init_weights'] = False model = GoogLeNet(**kwargs) model.load_state_dict(model_zoo.load_url(model_urls['googlenet'])) @@ -26,51 +28,35 @@ def googlenet(pretrained=False, **kwargs): return GoogLeNet(**kwargs) -def googlenet_bn(pretrained=False, **kwargs): - r"""GoogLeNet (Inception v1) model architecture with batch normalization from - `"Going Deeper with Convolutions" `_. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - if pretrained: - kwargs['init_weights'] = False - model = GoogLeNet(batch_norm=True, **kwargs) - model.load_state_dict(model_zoo.load_url(model_urls['googlenet_bn'])) - return model - - return GoogLeNet(batch_norm=True, **kwargs) - - class GoogLeNet(nn.Module): - def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False, init_weights=True): + def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): super(GoogLeNet, self).__init__() self.aux_logits = aux_logits + self.transform_input = transform_input - self.conv1 = BasicConv2d(3, 64, batch_norm, kernel_size=7, stride=2, padding=3) + self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) - self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) - self.conv2 = BasicConv2d(64, 64, batch_norm, kernel_size=1) - self.conv3 = BasicConv2d(64, 192, batch_norm, kernel_size=3, padding=1) - self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) + self.conv2 = BasicConv2d(64, 64, kernel_size=1) + self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) - self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32, batch_norm) - self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64, batch_norm) + self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) - self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64, batch_norm) - self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64, batch_norm) - self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64, batch_norm) - self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64, batch_norm) - self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128, batch_norm) + self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) + 
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) - self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128, batch_norm) - self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128, batch_norm) + self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) if aux_logits: - self.aux1 = InceptionAux(512, num_classes, batch_norm) - self.aux2 = InceptionAux(528, num_classes, batch_norm) + self.aux1 = InceptionAux(512, num_classes) + self.aux2 = InceptionAux(528, num_classes) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.dropout = nn.Dropout(0.4) self.fc = nn.Linear(1024, num_classes) @@ -92,12 +78,16 @@ def _initialize_weights(self): nn.init.constant_(m.bias, 0) def forward(self, x): + if self.transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.conv1(x) x = self.maxpool1(x) - x = self.lrn1(x) x = self.conv2(x) x = self.conv3(x) - x = self.lrn2(x) x = self.maxpool2(x) x = self.inception3a(x) @@ -129,24 +119,24 @@ def forward(self, x): class Inception(nn.Module): - def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, batch_norm=False): + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): super(Inception, self).__init__() - self.branch1 = BasicConv2d(in_channels, ch1x1, batch_norm, kernel_size=1) + self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) self.branch2 = nn.Sequential( - BasicConv2d(in_channels, ch3x3red, batch_norm, kernel_size=1), - BasicConv2d(ch3x3red, ch3x3, batch_norm, kernel_size=3, padding=1) + BasicConv2d(in_channels, ch3x3red, kernel_size=1), + BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) ) self.branch3 = nn.Sequential( - BasicConv2d(in_channels, ch5x5red, batch_norm, kernel_size=1), - BasicConv2d(ch5x5red, ch5x5, batch_norm, kernel_size=5, padding=2) + BasicConv2d(in_channels, ch5x5red, kernel_size=1), + BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) ) self.branch4 = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), - BasicConv2d(in_channels, pool_proj, batch_norm, kernel_size=1) + BasicConv2d(in_channels, pool_proj, kernel_size=1) ) def forward(self, x): @@ -161,11 +151,11 @@ def forward(self, x): class InceptionAux(nn.Module): - def __init__(self, in_channels, num_classes, batch_norm=False): + def __init__(self, in_channels, num_classes): super(InceptionAux, self).__init__() - self.conv = BasicConv2d(in_channels, 128, batch_norm, kernel_size=1) + self.conv = BasicConv2d(in_channels, 128, kernel_size=1) - self.fc1 = nn.Linear(128 * 4 * 4, 1024) + self.fc1 = nn.Linear(2048, 1024) self.fc2 = nn.Linear(1024, num_classes) def forward(self, x): @@ -182,18 +172,12 @@ def forward(self, x): class BasicConv2d(nn.Module): - def __init__(self, in_channels, out_channels, batch_norm=False, **kwargs): + def __init__(self, in_channels, out_channels, **kwargs): super(BasicConv2d, self).__init__() - self.batch_norm = batch_norm - - if self.batch_norm: - self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) - self.bn = nn.BatchNorm2d(out_channels, eps=0.001) - else: - self.conv = nn.Conv2d(in_channels, 
out_channels, **kwargs) + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) def forward(self, x): x = self.conv(x) - if self.batch_norm: - x = self.bn(x) + x = self.bn(x) return F.relu(x, inplace=True) From c02bf0055541b526e10922ce7e75d9eb44671578 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 7 Mar 2019 09:55:05 -0800 Subject: [PATCH 10/11] Update url of pre-trained model and add classification results on ImageNet --- docs/source/models.rst | 1 + torchvision/models/googlenet.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index e8876d445cd..cdcd7533862 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -87,6 +87,7 @@ Densenet-169 24.00 7.00 Densenet-201 22.80 6.43 Densenet-161 22.35 6.20 Inception v3 22.55 6.44 +GoogleNet 31.67 11.45 ================================ ============= ============= diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 31ae3418e2a..5e2d48430a1 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -7,7 +7,7 @@ model_urls = { # GoogLeNet ported from TensorFlow - 'googlenet': 'https://github.com/TheCodez/vision/releases/download/1.0/googlenet-1378be20.pth', + 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', } From 23562cfacc5042f3f5e2e0c82670a0df1305b3b4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 7 Mar 2019 11:44:38 -0800 Subject: [PATCH 11/11] Bugfix that improves performance by 1 point --- docs/source/models.rst | 2 +- torchvision/models/googlenet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index cdcd7533862..308ba75481b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -87,7 +87,7 @@ Densenet-169 24.00 7.00 Densenet-201 22.80 6.43 Densenet-161 22.35 6.20 Inception v3 22.55 6.44 -GoogleNet 31.67 11.45 +GoogleNet 30.22 10.47 ================================ ============= ============= diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 5e2d48430a1..9f50d93147e 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -50,7 +50,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) - self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
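
Usage sketch (not part of any patch above): assuming the series has been applied to torchvision and that the checkpoint at the patch-10 URL matches the module names defined here, the snippet below shows how the resulting GoogLeNet is constructed and, in particular, how the auxiliary classifiers added in patch 01 surface only in training mode, per the forward() in this series:

    import torch
    import torchvision.models as models

    # Random-weight model, as documented in docs/source/models.rst.
    net = models.googlenet()

    # Pre-trained model: per patches 06 and 09, pretrained=True sets
    # init_weights=False and transform_input=True before loading the
    # TensorFlow-ported weights from model_urls['googlenet'].
    net = models.googlenet(pretrained=True)

    x = torch.randn(1, 3, 224, 224)

    # In training mode with aux_logits=True (the default), forward()
    # returns the two auxiliary classifier outputs followed by the
    # main logits, i.e. the (aux1, aux2, x) tuple from patch 01.
    net.train()
    aux1, aux2, logits = net(x)

    # In eval mode only the main logits are returned.
    net.eval()
    with torch.no_grad():
        logits = net(x)
    print(logits.shape)  # torch.Size([1, 1000])

Switching to eval mode collapses the return value to a single tensor, so inference-only callers can use this model like the other torchvision classification models without handling the auxiliary outputs.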