diff --git a/test/test_models.py b/test/test_models.py
index 3696e78db30..00ea4e65f93 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -20,6 +20,11 @@ def get_available_detection_models():
     return [k for k, v in models.detection.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
 
 
+def get_available_video_models():
+    # TODO add a registration mechanism to torchvision.models
+    return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
+
+
 class Tester(unittest.TestCase):
     def _test_classification_model(self, name, input_shape):
         # passing num_class equal to a number other than 1000 helps in making the test
@@ -53,6 +58,16 @@ def _test_detection_model(self, name):
         self.assertTrue("scores" in out[0])
         self.assertTrue("labels" in out[0])
 
+    def _test_video_model(self, name):
+        # the default input shape is
+        # bs * num_channels * clip_len * h * w
+        input_shape = (1, 3, 8, 112, 112)
+        # test both BasicBlock and Bottleneck based models
+        model = models.video.__dict__[name](num_classes=50)
+        x = torch.rand(input_shape)
+        out = model(x)
+        self.assertEqual(out.shape[-1], 50)
+
     def _make_sliced_model(self, model, stop_layer):
         layers = OrderedDict()
         for name, layer in model.named_children():
@@ -130,6 +145,12 @@ def do_test(self, model_name=model_name):
 
     setattr(Tester, "test_" + model_name, do_test)
 
+for model_name in get_available_video_models():
+
+    def do_test(self, model_name=model_name):
+        self._test_video_model(model_name)
+
+    setattr(Tester, "test_" + model_name, do_test)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py
index f4b76156012..413ecf7b456 100644
--- a/torchvision/models/__init__.py
+++ b/torchvision/models/__init__.py
@@ -10,3 +10,4 @@
 from .shufflenetv2 import *
 from . import segmentation
 from . import detection
+from . import video
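Once the files below are in place, the new subpackage is importable like `models.segmentation` and `models.detection`. A quick sketch (illustrative, not part of the patch) of what the discovery filter in `get_available_video_models` above picks up:

```python
from torchvision import models

# same filter used by get_available_video_models() in the test
available = sorted(
    k for k, v in models.video.__dict__.items()
    if callable(v) and k[0].lower() == k[0] and k[0] != "_"
)
print(available)  # expected: ['mc3_18', 'r2plus1d_18', 'r3d_18']
```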
diff --git a/torchvision/models/video/__init__.py b/torchvision/models/video/__init__.py
new file mode 100644
index 00000000000..e6e663c50c0
--- /dev/null
+++ b/torchvision/models/video/__init__.py
@@ -0,0 +1,3 @@
+from .r3d import *
+from .r2plus1d import *
+from .mixed_conv import *
diff --git a/torchvision/models/video/_utils.py b/torchvision/models/video/_utils.py
new file mode 100644
index 00000000000..03feb83b8c4
--- /dev/null
+++ b/torchvision/models/video/_utils.py
@@ -0,0 +1,72 @@
+import torch.nn as nn
+
+
+__all__ = ["Conv3DSimple", "Conv2Plus1D", "Conv3DNoTemporal"]
+
+
+class Conv3DSimple(nn.Conv3d):
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DSimple, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(3, 3, 3),
+            stride=stride,
+            padding=padding,
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return (stride, stride, stride)
+
+
+class Conv2Plus1D(nn.Sequential):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes,
+                 stride=1,
+                 padding=1):
+        conv1 = [
+            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
+                      stride=(1, stride, stride), padding=(0, padding, padding),
+                      bias=False),
+            nn.BatchNorm3d(midplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
+                      stride=(stride, 1, 1), padding=(padding, 0, 0),
+                      bias=False)
+        ]
+        super(Conv2Plus1D, self).__init__(*conv1)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return (stride, stride, stride)
+
+
+class Conv3DNoTemporal(nn.Conv3d):
+
+    def __init__(self,
+                 in_planes,
+                 out_planes,
+                 midplanes=None,
+                 stride=1,
+                 padding=1):
+
+        super(Conv3DNoTemporal, self).__init__(
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=(1, 3, 3),
+            stride=(1, stride, stride),
+            padding=(0, padding, padding),
+            bias=False)
+
+    @staticmethod
+    def get_downsample_stride(stride):
+        return (1, stride, stride)
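The `midplanes` argument threaded through these builders sizes the intermediate channels of the (2+1)D factorization so that `Conv2Plus1D` matches the parameter count of a full 3x3x3 convolution (Sec. 3.5 of https://arxiv.org/abs/1711.11248); the two 3D variants accept and ignore it. A small sanity check, illustrative only and not part of the patch:

```python
import torch.nn as nn

in_planes, out_planes = 64, 64
# same formula the blocks in video_trunk.py use
midplanes = (in_planes * out_planes * 3 * 3 * 3) // (in_planes * 3 * 3 + 3 * out_planes)

full_3d = nn.Conv3d(in_planes, out_planes, kernel_size=(3, 3, 3), bias=False)
spatial = nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), bias=False)
temporal = nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), bias=False)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(full_3d))                       # 110592
print(n_params(spatial) + n_params(temporal))  # 110592, with midplanes == 144
```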
diff --git a/torchvision/models/video/mixed_conv.py b/torchvision/models/video/mixed_conv.py
new file mode 100644
index 00000000000..3bd631fc1e2
--- /dev/null
+++ b/torchvision/models/video/mixed_conv.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+
+from ._utils import Conv3DSimple, Conv3DNoTemporal
+from .video_stems import get_default_stem
+from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck
+
+
+__all__ = ["mc3_18"]
+
+
+def _mcX(model_depth, X=3, use_pool1=False, **kwargs):
+    """Generate mixed convolution network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        model_depth (int): trunk depth - supports most resnet depths
+        X (int): up to which layer the convolutions are 3D
+        use_pool1 (bool, optional): Add pooling layer to the stem. Defaults to False.
+
+    Returns:
+        nn.Module: mcX video trunk
+    """
+    assert X > 1 and X <= 5
+    conv_makers = [Conv3DSimple] * (X - 2)
+    while len(conv_makers) < 5:
+        conv_makers.append(Conv3DNoTemporal)
+
+    if model_depth < 50:
+        block = BasicBlock
+    else:
+        block = Bottleneck
+
+    model = VideoTrunkBuilder(block=block, conv_makers=conv_makers, model_depth=model_depth,
+                              stem=get_default_stem(use_pool1=use_pool1), **kwargs)
+
+    return model
+
+
+def _rmcX(model_depth, X=3, use_pool1=False, **kwargs):
+    """Generate reverse mixed convolution network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        model_depth (int): trunk depth - supports most resnet depths
+        X (int): up to which layer the convolutions are 2D
+        use_pool1 (bool, optional): Add pooling layer to the stem. Defaults to False.
+
+    Returns:
+        nn.Module: rmcX video trunk
+    """
+    assert X > 1 and X <= 5
+
+    conv_makers = [Conv3DNoTemporal] * (X - 2)
+    while len(conv_makers) < 5:
+        conv_makers.append(Conv3DSimple)
+
+    if model_depth < 50:
+        block = BasicBlock
+    else:
+        block = Bottleneck
+
+    model = VideoTrunkBuilder(block=block, conv_makers=conv_makers, model_depth=model_depth,
+                              stem=get_default_stem(use_pool1=use_pool1), **kwargs)
+
+    return model
+
+
+def mc3_18(use_pool1=False, **kwargs):
+    """Constructor for 18 layer Mixed Convolution network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        use_pool1 (bool, optional): Include pooling in the resnet stem. Defaults to False.
+
+    Returns:
+        nn.Module: MC3 network definition
+    """
+    return _mcX(18, 3, use_pool1, **kwargs)
diff --git a/torchvision/models/video/r2plus1d.py b/torchvision/models/video/r2plus1d.py
new file mode 100644
index 00000000000..0ab13419570
--- /dev/null
+++ b/torchvision/models/video/r2plus1d.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+from ._utils import Conv2Plus1D
+from .video_stems import get_r2plus1d_stem
+from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck
+
+
+__all__ = ["r2plus1d_18"]
+
+
+def _r2plus1d(model_depth, use_pool1=False, **kwargs):
+    """Constructor for R(2+1)D network as described in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        model_depth (int): Depth of the model - standard resnet depths apply
+        use_pool1 (bool, optional): Should we use the pooling layer? Defaults to False
+
+    Returns:
+        nn.Module: An R(2+1)D video backbone
+    """
+    convs = [Conv2Plus1D] * 4
+    if model_depth < 50:
+        block = BasicBlock
+    else:
+        block = Bottleneck
+
+    model = VideoTrunkBuilder(
+        block=block, conv_makers=convs, model_depth=model_depth,
+        stem=get_r2plus1d_stem(use_pool1), **kwargs)
+    return model
+
+
+def r2plus1d_18(use_pool1=False, **kwargs):
+    """Constructor for the 18 layer deep R(2+1)D network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        use_pool1 (bool, optional): Include pooling in the resnet stem. Defaults to False.
+
+    Returns:
+        nn.Module: R(2+1)D-18 network
+    """
+    return _r2plus1d(18, use_pool1, **kwargs)
diff --git a/torchvision/models/video/r3d.py b/torchvision/models/video/r3d.py
new file mode 100644
index 00000000000..08fe4f3b219
--- /dev/null
+++ b/torchvision/models/video/r3d.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+from ._utils import Conv3DSimple
+from .video_stems import get_default_stem
+from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck
+
+__all__ = ["r3d_18"]
+
+
+def _r3d(model_depth, use_pool1=False, **kwargs):
+    """Constructor of an r3d network as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        model_depth (int): resnet trunk depth
+        use_pool1 (bool, optional): Add pooling layer to the stem. Defaults to False
+
+    Returns:
+        nn.Module: R3D network trunk
+    """
+
+    conv_makers = [Conv3DSimple] * 4
+    if model_depth < 50:
+        block = BasicBlock
+    else:
+        block = Bottleneck
+
+    model = VideoTrunkBuilder(block=block, conv_makers=conv_makers, model_depth=model_depth,
+                              stem=get_default_stem(use_pool1=use_pool1), **kwargs)
+    return model
+
+
+def r3d_18(use_pool1=False, **kwargs):
+    """Construct 18 layer Resnet3D model as in
+    https://arxiv.org/abs/1711.11248
+
+    Args:
+        use_pool1 (bool, optional): Include pooling in resnet stem. Defaults to False.
+
+    Returns:
+        nn.Module: R3D-18 network
+    """
+    return _r3d(18, use_pool1, **kwargs)
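All three 18-layer constructors share the same calling convention; a minimal smoke test mirroring `_test_video_model` above (a sketch: the clip size comes from the test and `num_classes=400` is the builder's default):

```python
import torch
from torchvision.models.video import r3d_18, mc3_18, r2plus1d_18

clip = torch.rand(1, 3, 8, 112, 112)  # batch x channels x frames x height x width
for ctor in (r3d_18, mc3_18, r2plus1d_18):
    model = ctor(num_classes=400)  # 400 is the VideoTrunkBuilder default
    model.eval()
    with torch.no_grad():
        out = model(clip)
    print(ctor.__name__, out.shape)  # torch.Size([1, 400])
```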
diff --git a/torchvision/models/video/video_stems.py b/torchvision/models/video/video_stems.py
new file mode 100644
index 00000000000..4813aa84390
--- /dev/null
+++ b/torchvision/models/video/video_stems.py
@@ -0,0 +1,48 @@
+import torch.nn as nn
+
+
+def get_default_stem(use_pool1=False):
+    """The default conv-batchnorm-relu(-maxpool) stem
+
+    Args:
+        use_pool1 (bool, optional): Should the stem include the default maxpool? Defaults to False.
+
+    Returns:
+        nn.Sequential: Conv1 stem of resnet based models.
+    """
+
+    m = [
+        nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
+                  padding=(1, 3, 3), bias=False),
+        nn.BatchNorm3d(64),
+        nn.ReLU(inplace=True)]
+    if use_pool1:
+        m.append(nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1))
+    return nn.Sequential(*m)
+
+
+def get_r2plus1d_stem(use_pool1=False):
+    """The R(2+1)D stem differs from the default one in that it uses a separated 3D convolution
+
+    Args:
+        use_pool1 (bool, optional): Should the stem contain a pool1 layer? Defaults to False.
+
+    Returns:
+        nn.Sequential: the stem of the conv-separated network.
+    """
+
+    m = [
+        nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
+                  stride=(1, 2, 2), padding=(0, 3, 3),
+                  bias=False),
+        nn.BatchNorm3d(45),
+        nn.ReLU(inplace=True),
+        nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
+                  stride=(1, 1, 1), padding=(1, 0, 0),
+                  bias=False),
+        nn.BatchNorm3d(64),
+        nn.ReLU(inplace=True)]
+
+    if use_pool1:
+        m.append(nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1))
+    return nn.Sequential(*m)
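Both stems keep the temporal resolution and halve the spatial one unless `use_pool1` is set; a shape sketch (illustrative, not part of the patch) for the test-sized input:

```python
import torch
from torchvision.models.video.video_stems import get_default_stem, get_r2plus1d_stem

x = torch.rand(1, 3, 8, 112, 112)
print(get_default_stem()(x).shape)                # (1, 64, 8, 56, 56)
print(get_r2plus1d_stem()(x).shape)               # (1, 64, 8, 56, 56)
print(get_default_stem(use_pool1=True)(x).shape)  # (1, 64, 4, 28, 28)
```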
diff --git a/torchvision/models/video/video_trunk.py b/torchvision/models/video/video_trunk.py
new file mode 100644
index 00000000000..e28d60c611b
--- /dev/null
+++ b/torchvision/models/video/video_trunk.py
@@ -0,0 +1,189 @@
+import inspect
+import torch
+import torch.nn as nn
+
+from .video_stems import get_default_stem
+from ._utils import Conv3DNoTemporal
+
+
+BLOCK_CONFIG = {
+    10: (1, 1, 1, 1),
+    16: (2, 2, 2, 1),
+    18: (2, 2, 2, 2),
+    26: (2, 3, 4, 3),
+    34: (3, 4, 6, 3),
+    50: (3, 4, 6, 3),
+    101: (3, 4, 23, 3),
+    152: (3, 8, 36, 3)
+}
+
+
+class BasicBlock(nn.Module):
+
+    expansion = 1
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+        # midplanes matches the (2+1)D parameter count to a full 3x3x3 conv;
+        # the 3D conv builders ignore it
+        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
+
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Sequential(
+            conv_builder(inplanes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes),
+            nn.BatchNorm3d(planes)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+
+        super(Bottleneck, self).__init__()
+        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
+
+        # 1x1x1
+        self.conv1 = nn.Sequential(
+            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+        # Second kernel
+        self.conv2 = nn.Sequential(
+            conv_builder(planes, planes, midplanes, stride),
+            nn.BatchNorm3d(planes),
+            nn.ReLU(inplace=True)
+        )
+
+        # 1x1x1
+        self.conv3 = nn.Sequential(
+            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
+            nn.BatchNorm3d(planes * self.expansion)
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class VideoTrunkBuilder(nn.Module):
+
+    def __init__(self, block, conv_makers, model_depth,
+                 stem=None,
+                 num_classes=400,
+                 zero_init_residual=False):
+        """Generic resnet video generator.
+
+        Args:
+            block (nn.Module): resnet building block
+            conv_makers (list(functions)): generator function for each layer
+            model_depth (int): depth of the model; supports traditional resnet depths.
+            stem (nn.Sequential, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
+            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
+            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
+        """
+        super(VideoTrunkBuilder, self).__init__()
+        layers = BLOCK_CONFIG[model_depth]
+        self.inplanes = 64
+
+        if stem is None:
+            self.conv1 = get_default_stem()
+        else:
+            self.conv1 = stem
+
+        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
+        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)
+
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        # init weights
+        self._initialize_weights()
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.conv3[1].weight, 0)  # conv3[1] is the block's last BN; there is no bn3 attribute
+
+    def forward(self, x):
+        x = self.conv1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        # Flatten the layer to fc
+        x = x.flatten(1)
+        x = self.fc(x)
+
+        return x
+
+    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
+        downsample = None
+
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            ds_stride = conv_builder.get_downsample_stride(stride)
+            downsample = nn.Sequential(
+                nn.Conv3d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=ds_stride, bias=False),
+                nn.BatchNorm3d(planes * block.expansion)
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, conv_builder))
+
+        return nn.Sequential(*layers)
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv3d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out',
+                                        nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm3d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
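Beyond the 18-layer entry points, `VideoTrunkBuilder` composes any `BLOCK_CONFIG` depth with any conv maker. A hypothetical Bottleneck-based 50-layer R3D trunk (a sketch of a combination this patch does not itself test or export):

```python
import torch
from torchvision.models.video.video_trunk import VideoTrunkBuilder, Bottleneck
from torchvision.models.video._utils import Conv3DSimple

# model_depth=50 selects the (3, 4, 6, 3) layout; Bottleneck has expansion=4,
# so the final FC layer sees 512 * 4 = 2048 features
model = VideoTrunkBuilder(block=Bottleneck,
                          conv_makers=[Conv3DSimple] * 4,
                          model_depth=50,
                          num_classes=400)
out = model(torch.rand(1, 3, 8, 112, 112))
print(out.shape)  # torch.Size([1, 400])
```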