Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StochasticDepth implementation #4301

Merged
merged 11 commits into from
Aug 20, 2021
35 changes: 35 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
from abc import ABC, abstractmethod
import pytest
import random

import numpy as np

Expand All @@ -13,6 +14,11 @@
from torchvision import ops
from typing import Tuple

try:
from scipy import stats
except ImportError:
stats = None


class RoIOpTester(ABC):
dtype = torch.float64
Expand Down Expand Up @@ -1000,5 +1006,34 @@ def gen_iou_check(box, expected, tolerance=1e-4):
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)


class TestStochasticDepth:
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
datumbox marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize('mode', ["batch", "row"])
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
def test_stochastic_depth(self, mode, p):
random.seed(42)
datumbox marked this conversation as resolved.
Show resolved Hide resolved
batch_size = 5
x = torch.ones(size=(batch_size, 3, 4, 4))
layer = ops.StochasticDepth(mode=mode, p=p).to(device=x.device, dtype=x.dtype)
layer.__repr__()

trials = 250
num_samples = 0
counts = 0
for _ in range(trials):
out = layer(x)
non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
if mode == "batch":
if non_zero_count == 0:
counts += 1
num_samples += 1
elif mode == "row":
counts += batch_size - non_zero_count
num_samples += batch_size

p_value = stats.binom_test(counts, num_samples, p=p)
assert p_value > 0.0001


if __name__ == '__main__':
pytest.main([__file__])
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .stochastic_depth import stochastic_depth, StochasticDepth

from ._register_onnx_ops import _register_custom_op

Expand All @@ -20,5 +21,5 @@
'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork',
'sigmoid_focal_loss'
'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth'
]
56 changes: 56 additions & 0 deletions torchvision/ops/stochastic_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
from torch import nn, Tensor


def stochastic_depth(input: Tensor, mode: str, p: float, training: bool = True) -> Tensor:
datumbox marked this conversation as resolved.
Show resolved Hide resolved
"""
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
branches of residual architectures.

Args:
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
being its batch i.e. a batch with ``N`` rows.
mode (str): ``"batch"`` or ``"row"``.
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
randomly selected rows from the batch.
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
p (float): probability of the input to be zeroed.
training: apply dropout if is ``True``. Default: ``True``

Returns:
Tensor[N, ...]: The randomly zeroed tensor.
"""
if p < 0.0 or p > 1.0:
raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p))
if not training or p == 0.0:
return input

survival_rate = 1.0 - p
if mode == "batch":
keep = torch.rand(size=(1, ), dtype=input.dtype, device=input.device) < survival_rate
elif mode == "row":
keep = torch.rand(size=(input.size(0),), dtype=input.dtype, device=input.device) < survival_rate
keep = keep[(None, ) * (input.ndim - 1)].T
else:
raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode))
return input / survival_rate * keep
datumbox marked this conversation as resolved.
Show resolved Hide resolved
datumbox marked this conversation as resolved.
Show resolved Hide resolved


class StochasticDepth(nn.Module):
"""
See :func:`stochastic_depth`.
"""
def __init__(self, mode: str, p: float) -> None:
datumbox marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
self.mode = mode
self.p = p

def forward(self, input: Tensor) -> Tensor:
return stochastic_depth(input, self.mode, self.p, self.training)

def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'mode=' + str(self.mode)
tmpstr += ', p=' + str(self.p)
tmpstr += ')'
return tmpstr