Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

feat: Add Retinanet and backbones for detection #121

Merged
merged 20 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Any, Optional, Tuple

import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV
Expand All @@ -32,12 +33,36 @@
BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]


def backbone_and_num_features(model_name: str, *args, **kwargs) -> Tuple[nn.Module, int]:
def backbone_and_num_features(
model_name: str,
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
fpn: bool = False,
pretrained: bool = True,
trainable_backbone_layers: int = 3,
**kwargs
) -> Tuple[nn.Module, int]:
"""
>>> backbone_and_num_features('mobilenet_v2') # doctest: +ELLIPSIS
(Sequential(...), 1280)
>>> backbone_and_num_features('resnet50', fpn=True) # doctest: +ELLIPSIS
(BackboneWithFPN(...), 256)
>>> backbone_and_num_features('swav-imagenet') # doctest: +ELLIPSIS
(Sequential(...), 2048)
"""
if fpn:
if model_name in RESNET_MODELS:
backbone = resnet_fpn_backbone(
model_name, pretrained=pretrained, trainable_layers=trainable_backbone_layers, **kwargs
)
fpn_out_channels = 256
Borda marked this conversation as resolved.
Show resolved Hide resolved
return backbone, fpn_out_channels
else:
rank_zero_warn(f"{model_name} backbone is not supported with `fpn=True`, `fpn` won't be added.")

if model_name in BOLTS_MODELS:
return bolts_backbone_and_num_features(model_name)

if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, *args, **kwargs)
return torchvision_backbone_and_num_features(model_name, pretrained)

raise ValueError(f"{model_name} is not supported yet.")

Expand Down
2 changes: 1 addition & 1 deletion flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(

self.save_hyperparameters()

self.backbone, num_features = backbone_and_num_features(backbone, pretrained)
self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)

self.head = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Expand Down
85 changes: 71 additions & 14 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torchvision
from torch import nn
from torch.optim import Optimizer
from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import box_iou

from flash.core import Task
from flash.vision.backbones import backbone_and_num_features
from flash.vision.detection.data import ObjectDetectionDataPipeline
from flash.vision.detection.finetuning import ObjectDetectionFineTuning

_models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn}
_models = {
"fasterrcnn": torchvision.models.detection.fasterrcnn_resnet50_fpn,
"retinanet": torchvision.models.detection.retinanet_resnet50_fpn,
}


def _evaluate_iou(target, pred):
Expand All @@ -37,14 +44,20 @@ def _evaluate_iou(target, pred):


class ObjectDetector(Task):
"""Image detection task
"""Object detection task

Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts

Args:
num_classes: the number of classes for detection, including background
model: either a string of :attr`_models` or a custom nn.Module.
Defaults to 'fasterrcnn_resnet50_fpn'.
model: a string of :attr`_models`. Defaults to 'fasterrcnn'.
backbone: Pretained backbone CNN architecture. Constructs a model with a
ResNet-50-FPN backbone when no backbone is specified.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block.
Only applicable for `fasterrcnn`.
loss: the function(s) to update the model with. Has no effect for torchvision detection models.
metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
optimizer: The optimizer to use for training. Can either be the actual class or the class name.
Expand All @@ -57,23 +70,25 @@ class ObjectDetector(Task):
def __init__(
self,
num_classes: int,
model: Union[str, nn.Module] = "fasterrcnn_resnet50_fpn",
model: str = "fasterrcnn",
backbone: Optional[str] = None,
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
fpn: bool = True,
pretrained: bool = True,
pretrained_backbone: bool = True,
trainable_backbone_layers: int = 3,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic could/should be within FinetuningCallback.

If the user requires model= fasterrcnn, then it should choose the FasterRCNNFinetuning Callback.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the models will have the same FineTuningCallback as they would have similar backbones, but different heads. But yes, could think of moving trainable_backbone_layers for FineTuningCallback OR we could offer some options for finetuning functionalities to the User and it would override the trainable_backbone_layers.

loss=None,
metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None,
optimizer: Type[Optimizer] = torch.optim.Adam,
pretrained: bool = True,
learning_rate=1e-3,
**kwargs,
learning_rate: float = 1e-3,
**kwargs: Any,
):

self.save_hyperparameters()

if model in _models:
model = _models[model](pretrained=pretrained)
if isinstance(model, torchvision.models.detection.FasterRCNN):
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
model = ObjectDetector.get_model(
model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, **kwargs
)
else:
ValueError(f"{model} is not supported yet.")

Expand All @@ -85,6 +100,48 @@ def __init__(
optimizer=optimizer,
)

@staticmethod
def get_model(
model_name, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, **kwargs
):
if backbone is None:
# Constructs a model with a ResNet-50-FPN backbone when no backbone is specified.
if model_name == "fasterrcnn":
model = _models[model_name](
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
else:
model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone)
model.head = RetinaNetHead(
in_channels=model.backbone.out_channels,
num_anchors=model.head.classification_head.num_anchors,
num_classes=num_classes,
**kwargs
)
else:
backbone_model, num_features = backbone_and_num_features(
backbone,
fpn,
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
)
backbone_model.out_channels = num_features
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512), ),
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
aspect_ratios=((0.5, 1.0,
2.0), )) if not hasattr(backbone_model, "fpn") else None

if model_name == "fasterrcnn":
model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator)
else:
model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator)
return model

def training_step(self, batch, batch_idx) -> Any:
"""The training step. Overrides ``Task.training_step``
"""
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
assert pooling_fn in [torch.mean, torch.max]
self.pooling_fn = pooling_fn

self.backbone, num_features = backbone_and_num_features(backbone, pretrained)
self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)

if embedding_dim is None:
self.head = nn.Identity()
Expand Down
6 changes: 4 additions & 2 deletions tests/vision/detection/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@


@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
def test_detection(tmpdir):
@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", None), ("retinanet", "resnet34"),
("fasterrcnn", "mobilenet_v2"), ("retinanet", "simclr-imagenet")])
def test_detection(tmpdir, model, backbone):

train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)

data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
model = ObjectDetector(num_classes=data.num_classes)
model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes)

trainer = flash.Trainer(fast_dev_run=True)

Expand Down
6 changes: 4 additions & 2 deletions tests/vision/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset
Expand Down Expand Up @@ -62,8 +63,9 @@ def test_init():
assert {"boxes", "labels", "scores"} <= out[0].keys()


def test_training(tmpdir):
model = ObjectDetector(num_classes=2, model="fasterrcnn_resnet50_fpn")
@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"])
def test_training(tmpdir, model):
model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False)
ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
dl = DataLoader(ds, collate_fn=collate_fn)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down
13 changes: 13 additions & 0 deletions tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from flash.vision.backbones import backbone_and_num_features


@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280),
("simclr-imagenet", 2048)])
def test_fetch_fasterrcnn_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = backbone_and_num_features(model_name=backbone, pretrained=False, fpn=False)

assert backbone_model
assert num_features == expected_num_features