Skip to content

Commit

Permalink
Add tests for segmentation models
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa committed May 7, 2019
1 parent 7c9f4aa commit 349e611
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 17 deletions.
2 changes: 1 addition & 1 deletion references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def main(args):
sampler=test_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn)

model = torchvision.models.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss)
model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss)
model.to(device)
if args.distributed:
model = torch.nn.utils.convert_sync_batchnorm(model)
Expand Down
34 changes: 29 additions & 5 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import unittest


def get_available_models():
def get_available_classification_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0]]
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


def get_available_segmentation_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.segmentation.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


class Tester(unittest.TestCase):
def _test_model(self, name, input_shape):
def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test
# more enforcing in nature
model = models.__dict__[name](num_classes=50)
Expand All @@ -20,6 +25,16 @@ def _test_model(self, name, input_shape):
out = model(x)
self.assertEqual(out.shape[-1], 50)

def _test_segmentation_model(self, name):
# passing num_class equal to a number other than 1000 helps in making the test
# more enforcing in nature
model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
model.eval()
input_shape = (1, 3, 300, 300)
x = torch.rand(input_shape)
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))

def _make_sliced_model(self, model, stop_layer):
layers = OrderedDict()
for name, layer in model.named_children():
Expand All @@ -41,14 +56,23 @@ def test_resnet_dilation(self):
self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f))


for model_name in get_available_models():
for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_model(model_name, input_shape)
self._test_classification_model(model_name, input_shape)

setattr(Tester, "test_" + model_name, do_test)


for model_name in get_available_segmentation_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
self._test_segmentation_model(model_name)

setattr(Tester, "test_" + model_name, do_test)

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .googlenet import *
from .mobilenet import *
from .shufflenetv2 import *
from .segmentation import *
from . import segmentation
6 changes: 6 additions & 0 deletions torchvision/models/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ def forward(self, x):

return result


class FCN(_SimpleSegmentationModel):
pass


class DeepLabV3(_SimpleSegmentationModel):
pass

Expand All @@ -50,12 +52,14 @@ def __init__(self, in_channels, channels):
]

super(FCNHead, self).__init__(*layers)
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
"""


class DeepLabHead(nn.Sequential):
Expand All @@ -67,12 +71,14 @@ def __init__(self, in_channels, num_classes):
nn.ReLU(),
nn.Conv2d(256, num_classes, 1)
)
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
"""


class ASPPConv(nn.Sequential):
Expand Down
20 changes: 10 additions & 10 deletions torchvision/models/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .deeplabv3 import FCN, FCNHead, DeepLabHead, DeepLabV3


def _segm_resnet(name, backbone_name, num_classes, aux):
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
backbone = resnet.__dict__[backbone_name](
pretrained=True,
pretrained=pretrained_backbone,
replace_stride_with_dilation=[False, True, True])

return_layers = {'layer4': 'out'}
Expand All @@ -30,29 +30,29 @@ def _segm_resnet(name, backbone_name, num_classes, aux):
return model


def fcn_resnet50(pretrained=False, num_classes=21, aux_loss=None):
model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss)
def fcn_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs)
if pretrained:
pass
return model


def fcn_resnet101(pretrained=False, num_classes=21, aux_loss=None):
model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss)
def fcn_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained:
pass
return model


def deeplabv3_resnet50(pretrained=False, num_classes=21, aux_loss=None):
model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss)
def deeplabv3_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs)
if pretrained:
pass
return model


def deeplabv3_resnet101(pretrained=False, num_classes=21, aux_loss=None):
model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss)
def deeplabv3_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained:
pass
return model

0 comments on commit 349e611

Please sign in to comment.