Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge 76fa0b0 into be07c10
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored Feb 16, 2021
2 parents be07c10 + 76fa0b0 commit 25500b3
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 34 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed


## [Unreleased] - 2021-02-15

### Added

- Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121))

### Changed



### Fixed



### Removed



## [0.2.0] - 2021-02-12
Expand Down
67 changes: 57 additions & 10 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ The :class:`~flash.vision.ObjectDetector` is already pre-trained on `COCO train2
.. code-block::
annotation{
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float,
"bbox": [x,y,width,height],
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float,
"bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories[{
"id": int,
"name": str,
"id": int,
"name": str,
"supercategory": str,
}]
Expand Down Expand Up @@ -70,6 +70,8 @@ Finetuning

To tailor the object detector to your dataset, you would need to have it in `COCO Format <https://cocodataset.org/#format-data>`_, and then finetune the model.

.. tip:: You could also pass `trainable_backbone_layers` to :class:`~flash.vision.ObjectDetector` and train the model.

.. code-block:: python
import flash
Expand All @@ -88,7 +90,7 @@ To tailor the object detector to your dataset, you would need to have it in `COC
)
# 3. Build the model
model = ObjectDetector(num_classes=datamodule.num_classes)
model = ObjectDetector(model="fasterrcnn", backbone="simclr-imagenet", num_classes=datamodule.num_classes)
# 4. Create the trainer. Run thrice on data
trainer = flash.Trainer(max_epochs=3)
Expand All @@ -105,7 +107,52 @@ To tailor the object detector to your dataset, you would need to have it in `COC
Model
*****

By default, we use the `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_ model with a ResNet-50 FPN backbone. The inputs could be images of different sizes. The model behaves differently for training and evaluation. For training, it expects both the input tensors as well as the targets. And during evaluation, it expects only the input tensors and returns predictions for each image. The predictions are a list of boxes, labels and scores.
By default, we use the `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_ model with a ResNet-50 FPN backbone.
We also support `RetinaNet <https://arxiv.org/abs/1708.02002>`_.
The inputs could be images of different sizes.
The model behaves differently for training and evaluation.
For training, it expects both the input tensors as well as the targets. And during the evaluation, it expects only the input tensors and returns predictions for each image.
The predictions are a list of boxes, labels, and scores.

------

*********************
Changing the backbone
*********************
By default, we use a ResNet-50 FPN backbone. You can change the backbone for the model by passing in a different backbone.


.. code-block:: python
# 1. Organize the data
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
batch_size=2
)
# 2. Build the Task
model = ObjectDetector(model="retinanet", backbone="resnet101", num_classes=datamodule.num_classes)
Available backbones:

* resnet18
* resnet34
* resnet50
* resnet101
* resnet152
* resnext50_32x4d
* resnext101_32x8d
* mobilenet_v2
* vgg11
* vgg13
* vgg16
* vgg19
* densenet121
* densenet169
* densenet161
* swav-imagenet
* simclr-imagenet

------

Expand Down
39 changes: 35 additions & 4 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Any, Optional, Tuple

import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV
Expand All @@ -32,12 +33,42 @@
BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]


def backbone_and_num_features(model_name: str, *args, **kwargs) -> Tuple[nn.Module, int]:
def backbone_and_num_features(
model_name: str,
fpn: bool = False,
pretrained: bool = True,
trainable_backbone_layers: int = 3,
**kwargs
) -> Tuple[nn.Module, int]:
"""
Args:
model_name: backbone supported by `torchvision` and `bolts`
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block.
>>> backbone_and_num_features('mobilenet_v2') # doctest: +ELLIPSIS
(Sequential(...), 1280)
>>> backbone_and_num_features('resnet50', fpn=True) # doctest: +ELLIPSIS
(BackboneWithFPN(...), 256)
>>> backbone_and_num_features('swav-imagenet') # doctest: +ELLIPSIS
(Sequential(...), 2048)
"""
if fpn:
if model_name in RESNET_MODELS:
backbone = resnet_fpn_backbone(
model_name, pretrained=pretrained, trainable_layers=trainable_backbone_layers, **kwargs
)
fpn_out_channels = 256
return backbone, fpn_out_channels
else:
rank_zero_warn(f"{model_name} backbone is not supported with `fpn=True`, `fpn` won't be added.")

if model_name in BOLTS_MODELS:
return bolts_backbone_and_num_features(model_name)

if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, *args, **kwargs)
return torchvision_backbone_and_num_features(model_name, pretrained)

raise ValueError(f"{model_name} is not supported yet.")

Expand Down
2 changes: 1 addition & 1 deletion flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(

self.save_hyperparameters()

self.backbone, num_features = backbone_and_num_features(backbone, pretrained)
self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)

self.head = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Expand Down
89 changes: 75 additions & 14 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torchvision
from torch import nn
from torch.optim import Optimizer
from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import box_iou

from flash.core import Task
from flash.vision.backbones import backbone_and_num_features
from flash.vision.detection.data import ObjectDetectionDataPipeline
from flash.vision.detection.finetuning import ObjectDetectionFineTuning

_models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn}
_models = {
"fasterrcnn": torchvision.models.detection.fasterrcnn_resnet50_fpn,
"retinanet": torchvision.models.detection.retinanet_resnet50_fpn,
}


def _evaluate_iou(target, pred):
Expand All @@ -37,14 +44,20 @@ def _evaluate_iou(target, pred):


class ObjectDetector(Task):
"""Image detection task
"""Object detection task
Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts
Args:
num_classes: the number of classes for detection, including background
model: either a string of :attr`_models` or a custom nn.Module.
Defaults to 'fasterrcnn_resnet50_fpn'.
model: a string of :attr`_models`. Defaults to 'fasterrcnn'.
backbone: Pretained backbone CNN architecture. Constructs a model with a
ResNet-50-FPN backbone when no backbone is specified.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block.
Only applicable for `fasterrcnn`.
loss: the function(s) to update the model with. Has no effect for torchvision detection models.
metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
optimizer: The optimizer to use for training. Can either be the actual class or the class name.
Expand All @@ -57,23 +70,27 @@ class ObjectDetector(Task):
def __init__(
self,
num_classes: int,
model: Union[str, nn.Module] = "fasterrcnn_resnet50_fpn",
model: str = "fasterrcnn",
backbone: Optional[str] = None,
fpn: bool = True,
pretrained: bool = True,
pretrained_backbone: bool = True,
trainable_backbone_layers: int = 3,
anchor_generator: Optional[Type[AnchorGenerator]] = None,
loss=None,
metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None,
optimizer: Type[Optimizer] = torch.optim.Adam,
pretrained: bool = True,
learning_rate=1e-3,
**kwargs,
learning_rate: float = 1e-3,
**kwargs: Any,
):

self.save_hyperparameters()

if model in _models:
model = _models[model](pretrained=pretrained)
if isinstance(model, torchvision.models.detection.FasterRCNN):
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
model = ObjectDetector.get_model(
model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers,
anchor_generator, **kwargs
)
else:
ValueError(f"{model} is not supported yet.")

Expand All @@ -85,6 +102,50 @@ def __init__(
optimizer=optimizer,
)

@staticmethod
def get_model(
model_name, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers,
anchor_generator, **kwargs
):
if backbone is None:
# Constructs a model with a ResNet-50-FPN backbone when no backbone is specified.
if model_name == "fasterrcnn":
model = _models[model_name](
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
else:
model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone)
model.head = RetinaNetHead(
in_channels=model.backbone.out_channels,
num_anchors=model.head.classification_head.num_anchors,
num_classes=num_classes,
**kwargs
)
else:
backbone_model, num_features = backbone_and_num_features(
backbone,
fpn,
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
)
backbone_model.out_channels = num_features
if anchor_generator is None:
anchor_generator = AnchorGenerator(
sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), )
) if not hasattr(backbone_model, "fpn") else None

if model_name == "fasterrcnn":
model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator)
else:
model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator)
return model

def training_step(self, batch, batch_idx) -> Any:
"""The training step. Overrides ``Task.training_step``
"""
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
assert pooling_fn in [torch.mean, torch.max]
self.pooling_fn = pooling_fn

self.backbone, num_features = backbone_and_num_features(backbone, pretrained)
self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)

if embedding_dim is None:
self.head = nn.Identity()
Expand Down
6 changes: 4 additions & 2 deletions tests/vision/detection/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@


@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
def test_detection(tmpdir):
@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", None), ("retinanet", "resnet34"),
("fasterrcnn", "mobilenet_v2"), ("retinanet", "simclr-imagenet")])
def test_detection(tmpdir, model, backbone):

train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)

data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
model = ObjectDetector(num_classes=data.num_classes)
model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes)

trainer = flash.Trainer(fast_dev_run=True)

Expand Down
6 changes: 4 additions & 2 deletions tests/vision/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset
Expand Down Expand Up @@ -62,8 +63,9 @@ def test_init():
assert {"boxes", "labels", "scores"} <= out[0].keys()


def test_training(tmpdir):
model = ObjectDetector(num_classes=2, model="fasterrcnn_resnet50_fpn")
@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"])
def test_training(tmpdir, model):
model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False)
ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
dl = DataLoader(ds, collate_fn=collate_fn)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down
13 changes: 13 additions & 0 deletions tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from flash.vision.backbones import backbone_and_num_features


@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280),
("simclr-imagenet", 2048)])
def test_fetch_fasterrcnn_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = backbone_and_num_features(model_name=backbone, pretrained=False, fpn=False)

assert backbone_model
assert num_features == expected_num_features

0 comments on commit 25500b3

Please sign in to comment.