Add the S3D architecture to TorchVision (#6412)
* S3D initial commit
* add model builder code and docstrings
* change classifier submodule, populate weights enum
* fix change of block args from List[List[int]] to ints
* add VideoClassification to transforms
* edit weights url for testing, add s3d to models.video init
* norm_layer changes
* norm_layer and args fix
* Overwrite default dropout
* Remove docs from internal submodules.
* Fix tests
* Adding documentation.
* Link doc from main models.rst
* Fix min_temporal_size
* Adding crop/resize parameters in references script
* Release weights.
* Refactor dropout.
* Adding the weights table in the doc

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
1 parent 9c3e2bf · commit 6de7021 · Showing 6 changed files with 246 additions and 0 deletions.
docs/source/models.rst
@@ -516,6 +516,7 @@ pre-trained weights:

    models/video_mvit
    models/video_resnet
    models/video_s3d
docs/source/models/video_s3d.rst
@@ -0,0 +1,25 @@
Video S3D
=========

.. currentmodule:: torchvision.models.video

The S3D model is based on the
`Rethinking Spatiotemporal Feature Learning: Speed-Accuracy Trade-offs in Video Classification
<https://arxiv.org/abs/1712.04851>`__ paper.


Model builders
--------------

The following model builders can be used to instantiate an S3D model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.S3D`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/s3d.py>`_ for
more details about this class.

.. autosummary::
    :toctree: generated/
    :template: function.rst

    s3d
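As a brief usage sketch (not part of the committed file), the builder documented above can be called with or without the released Kinetics-400 weights, for example:

import torch
from torchvision.models.video import s3d, S3D_Weights

model = s3d(weights=None)                        # randomly initialized, num_classes defaults to 400
model = s3d(weights=S3D_Weights.KINETICS400_V1)  # pre-trained on Kinetics-400
model.eval()

# Dummy clip shaped (batch, channel, time, height, width); the weight metadata in
# s3d.py below requires at least 14 frames and a 224x224 spatial size.
clip = torch.rand(1, 3, 16, 224, 224)
with torch.no_grad():
    scores = model(clip)  # -> shape (1, 400)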
Binary file not shown.
torchvision/models/video/__init__.py
@@ -1,2 +1,3 @@
from .mvit import *
from .resnet import *
from .s3d import *
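For orientation (not part of the diff): with this wildcard import in place, the new builder and weights enum become importable from ``torchvision.models.video``, and, via the ``@register_model()`` decorator in s3d.py, the builder is also reachable through the model registry. A minimal sketch, assuming a torchvision build that ships the registration API:

import torchvision
from torchvision.models.video import S3D, S3D_Weights, s3d  # exposed by this __init__ change

# Registry route added by @register_model(); assumes torchvision.models.get_model is available.
model = torchvision.models.get_model("s3d", weights=None)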
torchvision/models/video/s3d.py
@@ -0,0 +1,216 @@
from functools import partial
from typing import Any, Callable, Optional

import torch
from torch import nn
from torchvision.ops.misc import Conv3dNormActivation

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param


__all__ = [
    "S3D",
    "S3D_Weights",
    "s3d",
]


class TemporalSeparableConv(nn.Sequential):
    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int,
        stride: int,
        padding: int,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__(
            Conv3dNormActivation(
                in_planes,
                out_planes,
                kernel_size=(1, kernel_size, kernel_size),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
                norm_layer=norm_layer,
            ),
            Conv3dNormActivation(
                out_planes,
                out_planes,
                kernel_size=(kernel_size, 1, 1),
                stride=(stride, 1, 1),
                padding=(padding, 0, 0),
                bias=False,
                norm_layer=norm_layer,
            ),
        )


class SepInceptionBlock3D(nn.Module):
    def __init__(
        self,
        in_planes: int,
        b0_out: int,
        b1_mid: int,
        b1_out: int,
        b2_mid: int,
        b2_out: int,
        b3_out: int,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__()

        self.branch0 = Conv3dNormActivation(in_planes, b0_out, kernel_size=1, stride=1, norm_layer=norm_layer)
        self.branch1 = nn.Sequential(
            Conv3dNormActivation(in_planes, b1_mid, kernel_size=1, stride=1, norm_layer=norm_layer),
            TemporalSeparableConv(b1_mid, b1_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer),
        )
        self.branch2 = nn.Sequential(
            Conv3dNormActivation(in_planes, b2_mid, kernel_size=1, stride=1, norm_layer=norm_layer),
            TemporalSeparableConv(b2_mid, b2_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            Conv3dNormActivation(in_planes, b3_out, kernel_size=1, stride=1, norm_layer=norm_layer),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)

        return out


class S3D(nn.Module):
    """S3D main class.
    Args:
        num_classes (int): number of classes for the classification task.
        dropout (float): dropout probability.
        norm_layer (Optional[Callable]): Module specifying the normalization layer to use.
    Inputs:
        x (Tensor): batch of videos with dimensions (batch, channel, time, height, width)
    """

    def __init__(
        self,
        num_classes: int = 400,
        dropout: float = 0.0,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm3d, eps=0.001, momentum=0.001)

        self.features = nn.Sequential(
            TemporalSeparableConv(3, 64, 7, 2, 3, norm_layer),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            Conv3dNormActivation(
                64,
                64,
                kernel_size=1,
                stride=1,
                norm_layer=norm_layer,
            ),
            TemporalSeparableConv(64, 192, 3, 1, 1, norm_layer),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            SepInceptionBlock3D(192, 64, 96, 128, 16, 32, 32, norm_layer),
            SepInceptionBlock3D(256, 128, 128, 192, 32, 96, 64, norm_layer),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
            SepInceptionBlock3D(480, 192, 96, 208, 16, 48, 64, norm_layer),
            SepInceptionBlock3D(512, 160, 112, 224, 24, 64, 64, norm_layer),
            SepInceptionBlock3D(512, 128, 128, 256, 24, 64, 64, norm_layer),
            SepInceptionBlock3D(512, 112, 144, 288, 32, 64, 64, norm_layer),
            SepInceptionBlock3D(528, 256, 160, 320, 32, 128, 128, norm_layer),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)),
            SepInceptionBlock3D(832, 256, 160, 320, 32, 128, 128, norm_layer),
            SepInceptionBlock3D(832, 384, 192, 384, 48, 128, 128, norm_layer),
        )
        self.avgpool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        x = torch.mean(x, dim=(2, 3, 4))
        return x


class S3D_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/s3d-1bd8ae63.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256, 256),
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 14,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/pull/6412#issuecomment-1219687434",
            "_docs": (
                "The weights are ported from a community repository. The accuracies are estimated on clip-level "
                "with parameters `frame_rate=15`, `clips_per_video=1`, and `clip_len=128`."
            ),
            "num_params": 8320048,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 67.315,
                    "acc@5": 87.593,
                }
            },
        },
    )
    DEFAULT = KINETICS400_V1


@register_model()
def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwargs: Any) -> S3D:
    """Construct Separable 3D CNN model.
    Reference: `Rethinking Spatiotemporal Feature Learning <https://arxiv.org/abs/1712.04851>`__.
    Args:
        weights (:class:`~torchvision.models.video.S3D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.S3D_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.S3D`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/s3d.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.video.S3D_Weights
        :members:
    """
    weights = S3D_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = S3D(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
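As a closing illustration (not part of the commit), a rough end-to-end inference sketch using the released weights and their bundled ``VideoClassification`` preset might look as follows; the input tensor here is a random stand-in for a decoded clip in (T, C, H, W) layout:

import torch
from torchvision.models.video import s3d, S3D_Weights

weights = S3D_Weights.DEFAULT  # alias of KINETICS400_V1
model = s3d(weights=weights).eval()

# The preset resizes to 256x256, center-crops to 224x224, converts to float and
# normalizes with mean/std 0.5, returning a (C, T, H, W) tensor for a single clip.
preprocess = weights.transforms()

video = torch.rand(32, 3, 256, 256)     # stand-in clip, (T, C, H, W)
batch = preprocess(video).unsqueeze(0)  # -> (1, C, T, 224, 224)

with torch.no_grad():
    probs = model(batch).softmax(dim=1)

top5 = probs.topk(5).indices[0]
print([weights.meta["categories"][int(i)] for i in top5])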