[fbsync] Add SSD architecture with VGG16 backbone (#3403)
Summary:
* Early skeleton of API.

* Adding MultiFeatureMap and vgg16 backbone.

* Making vgg16 backbone same as paper.

* Making code generic to support all vggs.

* Moving vgg's extra layers to a separate class + L2 scaling.

* Adding header vgg layers.

* Fix maxpool patching.

* Refactoring code to allow for support of different backbones & sizes:
- Skeleton for Default Boxes generator class
- Dynamic estimation of configuration when possible
- Addition of types

* Complete the implementation of DefaultBox generator.

* Replace randn with empty.

* Minor refactoring

* Making clamping between 0 and 1 optional.

* Change xywh to xyxy encoding.

* Adding parameters and reusing objects in constructor.

* Temporarily inherit from Retina to avoid dup code.

* Implement forward methods + temp workarounds to inherit from retina.

* Inherit more methods from retinanet.

* Fix type error.

* Add Regression loss.

* Fixing JIT issues.

* Change JIT workaround to minimize new code.

* Fixing initialization bug.

* Add classification loss.

* Update todos.

* Add weight loading support.

* Support SSD512.

* Change kernel_size to get output size 1x1

* Add xavier init and refactoring.

* Adding unit-tests and fixing JIT issues.

* Add a test for dbox generator.

* Remove unnecessary import.

* Workaround on GeneralizedRCNNTransform to support fixed size input.

* Remove unnecessary random calls from the test.

* Remove more rand calls from the test.

* change mapping and handling of empty labels

* Fix JIT warnings.

* Speed up loss.

* Convert 0-1 dboxes to original size.

* Fix warning.

* Fix tests.

* Update comments.

* Fixing minor bugs.

* Introduce a custom DBoxMatcher.

* Minor refactoring

* Move extra layer definition inside feature extractor.

* handle no bias on init.

* Remove fixed image size limitation

* Change initialization values for bias of classification head.

* Refactoring and update test file.

* Adding ResNet backbone.

* Minor refactoring.

* Remove inheritance of retina and general refactoring.

* SSD should fix the input size.

* Fixing messages and comments.

* Silently ignoring exception if test-only.

* Update comments.

* Update regression loss.

* Restore Xavier init everywhere, update the negative sampling method, change the clipping approach.

* Fixing tests.

* Refactor to move the losses from the Head to the SSD.

* Removing resnet50 ssd version.

* Adding support for the best-performing backbone and its config.

* Refactor and clean up the API.

* Fix lint

* Update todos and comments.

* Adding RandomHorizontalFlip and RandomIoUCrop transforms.

* Adding necessary checks to our transforms.

* Adding RandomZoomOut.

* Adding RandomPhotometricDistort.

* Moving Detection transforms to references.

* Update presets

* fix lint

* leave compose and object

* Adding scaling for completeness.

* Adding params in the repr

* Remove unnecessary import.

* minor refactoring

* Remove unnecessary call.

* Give better names to DBox* classes

* Port num_anchors estimation in generator

* Remove rescaling and fix presets

* Add the ability to pass a custom head and refactoring.

* fix lint

* Fix unit-test

* Update todos.

* Change mean values.

* Change the default parameter of SSD to train the full VGG16 and remove the exception catch used for eval-only runs.

* Adding documentation

* Adding weights and updating readmes.

* Update the model weights with a better performing model.

* Adding doc for head.

* Restore import.

Reviewed By: NicolasHug

Differential Revision: D28169152

fbshipit-source-id: cec34141fad09538e0a29c6eb7834b24e2d8528e
cpuhrsch authored and facebook-github-bot committed May 4, 2021
1 parent f766afa commit 30e7811
Showing 14 changed files with 1,032 additions and 57 deletions.
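
For orientation, here is a minimal usage sketch of the entry point this commit adds. It is illustrative only: it assumes the ssd300_vgg16 factory documented in the diffs below and the list-of-tensors input convention shared by the other torchvision detection models; the target fields shown (boxes in xyxy pixel coordinates, integer labels) follow the standard detection format.

```
import torch
import torchvision

# Sketch only: build the new SSD300 model with a VGG16 backbone.
# pretrained/pretrained_backbone are set to False to avoid weight downloads.
model = torchvision.models.detection.ssd300_vgg16(
    pretrained=False, pretrained_backbone=False, num_classes=91)

# Training mode expects one target dict per image and returns a dict of losses.
model.train()
images = [torch.rand(3, 300, 300)]
targets = [{
    'boxes': torch.tensor([[50., 60., 200., 220.]]),  # xyxy, absolute pixels
    'labels': torch.tensor([17]),
}]
losses = model(images, targets)  # regression + classification losses

# Inference mode returns one dict per image with 'boxes', 'scores' and 'labels'.
model.eval()
with torch.no_grad():
    detections = model([torch.rand(3, 480, 640)])
```
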
19 changes: 14 additions & 5 deletions docs/source/models.rst
@@ -381,17 +381,18 @@ Object Detection, Instance Segmentation and Person Keypoint Detection
The models subpackage contains definitions for the following model
architectures for detection:

- `Faster R-CNN ResNet-50 FPN <https://arxiv.org/abs/1506.01497>`_
- `Mask R-CNN ResNet-50 FPN <https://arxiv.org/abs/1703.06870>`_
- `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
- `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
- `RetinaNet <https://arxiv.org/abs/1708.02002>`_
- `SSD <https://arxiv.org/abs/1512.02325>`_

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision.

The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``.
The models internally resize the images so that they have a minimum size
of ``800``. This option can be changed by passing the option ``min_size``
to the constructor of the models.
The models internally resize the images but the behaviour varies depending
on the model. Check the constructor of the models for more information.
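
To make the input contract above concrete, a short illustrative snippet (not part of this file) showing how a detection model consumes a list of differently sized images:

```
import torch
import torchvision

# Any detection model accepts a list of 0-1 range Tensor[C, H, W] images of
# varying sizes; resizing happens inside the model. The SSD300 model added in
# this commit, for example, resizes every input to a fixed 300x300.
model = torchvision.models.detection.ssd300_vgg16(
    pretrained=False, pretrained_backbone=False).eval()
images = [torch.rand(3, 400, 500), torch.rand(3, 600, 600)]
with torch.no_grad():
    predictions = model(images)  # one dict per input image
```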


For object detection and instance segmentation, the pre-trained
@@ -425,6 +426,7 @@ Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD VGG16 25.1 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
====================================== ======= ======== ===========

@@ -483,6 +485,7 @@ Faster R-CNN ResNet-50 FPN 0.2288 0.0590
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD VGG16 0.2093 0.0744 1.5
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
====================================== =================== ================== ===========
@@ -502,6 +505,12 @@ RetinaNet
.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn


SSD
------------

.. autofunction:: torchvision.models.detection.ssd300_vgg16


Mask R-CNN
----------

8 changes: 8 additions & 0 deletions references/detection/README.md
@@ -48,6 +48,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```

### SSD VGG16
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model ssd300_vgg16 --epochs 120\
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
--weight-decay 0.0005 --data-augmentation ssd
```


### Mask R-CNN
```
22 changes: 16 additions & 6 deletions references/detection/presets.py
@@ -2,12 +2,22 @@


class DetectionPresetTrain:
def __init__(self, hflip_prob=0.5):
trans = [T.ToTensor()]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))

self.transforms = T.Compose(trans)
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
if data_augmentation == 'hflip':
self.transforms = T.Compose([
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
elif data_augmentation == 'ssd':
self.transforms = T.Compose([
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

def __call__(self, img, target):
return self.transforms(img, target)
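
As a quick illustration (not from this commit), the preset above can be exercised on a PIL image and a COCO-style target dict; with the 'ssd' policy it chains the photometric distortion, zoom-out, IoU crop and horizontal flip transforms added further down before converting to a tensor:

```
import torch
from PIL import Image

# Hypothetical smoke test of DetectionPresetTrain with the new 'ssd' policy.
transform = DetectionPresetTrain(data_augmentation='ssd')
img = Image.new('RGB', (640, 480))
target = {
    'boxes': torch.tensor([[100., 120., 300., 360.]]),  # xyxy
    'labels': torch.tensor([1]),
}
img, target = transform(img, target)  # img is now a tensor, boxes are adjusted
```
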
10 changes: 6 additions & 4 deletions references/detection/train.py
@@ -47,8 +47,8 @@ def get_dataset(name, image_set, transform, data_path):
return ds, num_classes


def get_transform(train):
return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
def get_transform(train, data_augmentation):
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()


def main(args):
@@ -60,8 +60,9 @@ def main(args):
# Data loading code
print("Loading data")

dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation),
args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)

print("Creating data loaders")
if args.distributed:
@@ -179,6 +180,7 @@ def main(args):
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
help='number of trainable layers of backbone')
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
parser.add_argument(
"--test-only",
dest="test_only",
230 changes: 210 additions & 20 deletions references/detection/transforms.py
@@ -1,6 +1,10 @@
import random
import torch
import torchvision

from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from typing import List, Tuple, Dict, Optional


def _flip_coco_person_keypoints(kps, width):
@@ -23,27 +27,213 @@ def __call__(self, image, target):
return image, target


class RandomHorizontalFlip(object):
def __init__(self, prob):
self.prob = prob

def __call__(self, image, target):
if random.random() < self.prob:
height, width = image.shape[-2:]
image = image.flip(-1)
bbox = target["boxes"]
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
target["boxes"] = bbox
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
width, _ = F._get_image_size(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
return image, target


class ToTensor(object):
def __call__(self, image, target):
class ToTensor(nn.Module):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.to_tensor(image)
return image, target


class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials

def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None:
raise ValueError("The targets can't be None for this transform.")

if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
elif image.ndimension() == 2:
image = image.unsqueeze(0)

orig_w, orig_h = F._get_image_size(image)

while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return image, target

for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue

# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue

# check for any valid boxes with centers within the crop area
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue

# check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area]
ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]],
dtype=boxes.dtype, device=boxes.device))
if ious.max() < min_jaccard_overlap:
continue

# keep only valid boxes and perform cropping
target["boxes"] = boxes
target["labels"] = target["labels"][is_within_crop_area]
target["boxes"][:, 0::2] -= left
target["boxes"][:, 1::2] -= top
target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
image = F.crop(image, top, left, new_h, new_w)

return image, target
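
A minimal standalone sketch (not part of the diff) of the crop transform above: only boxes whose centers fall inside the sampled window are kept, and they are shifted and clamped to the new canvas.

```
import torch

crop = RandomIoUCrop()
image = torch.rand(3, 480, 640)
target = {
    'boxes': torch.tensor([[ 40.,  50., 200., 210.],
                           [400., 300., 620., 470.]]),
    'labels': torch.tensor([3, 7]),
}
image, target = crop(image, target)  # some boxes may be dropped by the sampler
```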


class RandomZoomOut(nn.Module):
def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5):
super().__init__()
if fill is None:
fill = [0., 0., 0.]
self.fill = fill
self.side_range = side_range
if side_range[0] < 1. or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range))
self.p = p

@torch.jit.unused
def _get_fill_value(self, is_pil):
# type: (bool) -> int
# We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0

def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
elif image.ndimension() == 2:
image = image.unsqueeze(0)

if torch.rand(1) < self.p:
return image, target

orig_w, orig_h = F._get_image_size(image)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)

r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)

if torch.jit.is_scripting():
fill = 0
else:
fill = self._get_fill_value(F._is_pil_image(image))

image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor):
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \
image[..., :, (left + orig_w):] = v

if target is not None:
target["boxes"][:, 0::2] += left
target["boxes"][:, 1::2] += top

return image, target
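
Similarly, a hedged one-off example of the zoom-out transform above: the image is placed on a larger canvas (1x to 4x the original side length by default) filled with the given value, and the boxes are translated accordingly.

```
import torch

zoom_out = RandomZoomOut(fill=[123., 117., 104.])
image = torch.rand(3, 300, 300)
target = {'boxes': torch.tensor([[30., 40., 120., 150.]]), 'labels': torch.tensor([1])}
image, target = zoom_out(image, target)
```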


class RandomPhotometricDistort(nn.Module):
def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5),
hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5):
super().__init__()
self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast)
self._hue = T.ColorJitter(hue=hue)
self._saturation = T.ColorJitter(saturation=saturation)
self.p = p

def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
elif image.ndimension() == 2:
image = image.unsqueeze(0)

r = torch.rand(7)

if r[0] < self.p:
image = self._brightness(image)

contrast_before = r[1] < 0.5
if contrast_before:
if r[2] < self.p:
image = self._contrast(image)

if r[3] < self.p:
image = self._saturation(image)

if r[4] < self.p:
image = self._hue(image)

if not contrast_before:
if r[5] < self.p:
image = self._contrast(image)

if r[6] < self.p:
channels = F._get_image_num_channels(image)
permutation = torch.randperm(channels)

is_pil = F._is_pil_image(image)
if is_pil:
image = F.to_tensor(image)
image = image[..., permutation, :, :]
if is_pil:
image = F.to_pil_image(image)

return image, target
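
Finally, a short illustrative call (not part of the diff) for the photometric distortion above; it jitters brightness, contrast, saturation and hue independently with probability p and may shuffle the colour channels:

```
from PIL import Image

# Hypothetical usage; the transform leaves the target untouched.
distort = RandomPhotometricDistort(p=0.5)
img = Image.new('RGB', (300, 300), color=(128, 64, 32))
img, _ = distort(img, target=None)
```
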
1 change: 1 addition & 0 deletions test/test_models.py
Expand Up @@ -44,6 +44,7 @@ def get_available_video_models():
"maskrcnn_resnet50_fpn": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1],
}

