diff --git a/CHANGELOG.md b/CHANGELOG.md index 92ca4b67363..2552dd5b240 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed +## [Unreleased] - 2021-02-15 + +### Added + +- Added `RetinaNet` & `backbones` to `ObjectDetector` Task ([#121](https://github.com/PyTorchLightning/lightning-flash/pull/121)) + +### Changed + + + +### Fixed + + + +### Removed + ## [0.2.0] - 2021-02-12 diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 6b9ae98d06f..2840923ca07 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -22,18 +22,18 @@ The :class:`~flash.vision.ObjectDetector` is already pre-trained on `COCO train2 .. code-block:: annotation{ - "id": int, - "image_id": int, - "category_id": int, - "segmentation": RLE or [polygon], - "area": float, - "bbox": [x,y,width,height], + "id": int, + "image_id": int, + "category_id": int, + "segmentation": RLE or [polygon], + "area": float, + "bbox": [x,y,width,height], "iscrowd": 0 or 1, } categories[{ - "id": int, - "name": str, + "id": int, + "name": str, "supercategory": str, }] @@ -70,6 +70,8 @@ Finetuning To tailor the object detector to your dataset, you would need to have it in `COCO Format `_, and then finetune the model. +.. tip:: You could also pass `trainable_backbone_layers` to :class:`~flash.vision.ObjectDetector` and train the model. + .. code-block:: python import flash @@ -88,7 +90,7 @@ To tailor the object detector to your dataset, you would need to have it in `COC ) # 3. Build the model - model = ObjectDetector(num_classes=datamodule.num_classes) + model = ObjectDetector(model="fasterrcnn", backbone="simclr-imagenet", num_classes=datamodule.num_classes) # 4. Create the trainer. Run thrice on data trainer = flash.Trainer(max_epochs=3) @@ -105,7 +107,52 @@ To tailor the object detector to your dataset, you would need to have it in `COC Model ***** -By default, we use the `Faster R-CNN `_ model with a ResNet-50 FPN backbone. The inputs could be images of different sizes. The model behaves differently for training and evaluation. For training, it expects both the input tensors as well as the targets. And during evaluation, it expects only the input tensors and returns predictions for each image. The predictions are a list of boxes, labels and scores. +By default, we use the `Faster R-CNN `_ model with a ResNet-50 FPN backbone. +We also support `RetinaNet `_. +The inputs could be images of different sizes. +The model behaves differently for training and evaluation. +For training, it expects both the input tensors as well as the targets. And during the evaluation, it expects only the input tensors and returns predictions for each image. +The predictions are a list of boxes, labels, and scores. + +------ + +********************* +Changing the backbone +********************* +By default, we use a ResNet-50 FPN backbone. You can change the backbone for the model by passing in a different backbone. + + +.. code-block:: python + + # 1. Organize the data + datamodule = ObjectDetectionData.from_coco( + train_folder="data/coco128/images/train2017/", + train_ann_file="data/coco128/annotations/instances_train2017.json", + batch_size=2 + ) + + # 2. Build the Task + model = ObjectDetector(model="retinanet", backbone="resnet101", num_classes=datamodule.num_classes) + +Available backbones: + +* resnet18 +* resnet34 +* resnet50 +* resnet101 +* resnet152 +* resnext50_32x4d +* resnext101_32x8d +* mobilenet_v2 +* vgg11 +* vgg13 +* vgg16 +* vgg19 +* densenet121 +* densenet169 +* densenet161 +* swav-imagenet +* simclr-imagenet ------ diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index 8259af09c7d..9269ad21037 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Any, Optional, Tuple import torchvision -from pytorch_lightning.utilities import _BOLTS_AVAILABLE +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn as nn +from torchvision.models.detection.backbone_utils import resnet_fpn_backbone if _BOLTS_AVAILABLE: from pl_bolts.models.self_supervised import SimCLR, SwAV @@ -32,12 +33,42 @@ BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"] -def backbone_and_num_features(model_name: str, *args, **kwargs) -> Tuple[nn.Module, int]: +def backbone_and_num_features( + model_name: str, + fpn: bool = False, + pretrained: bool = True, + trainable_backbone_layers: int = 3, + **kwargs +) -> Tuple[nn.Module, int]: + """ + Args: + model_name: backbone supported by `torchvision` and `bolts` + fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. + pretrained: if true, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers: number of trainable resnet layers starting from final block. + + >>> backbone_and_num_features('mobilenet_v2') # doctest: +ELLIPSIS + (Sequential(...), 1280) + >>> backbone_and_num_features('resnet50', fpn=True) # doctest: +ELLIPSIS + (BackboneWithFPN(...), 256) + >>> backbone_and_num_features('swav-imagenet') # doctest: +ELLIPSIS + (Sequential(...), 2048) + """ + if fpn: + if model_name in RESNET_MODELS: + backbone = resnet_fpn_backbone( + model_name, pretrained=pretrained, trainable_layers=trainable_backbone_layers, **kwargs + ) + fpn_out_channels = 256 + return backbone, fpn_out_channels + else: + rank_zero_warn(f"{model_name} backbone is not supported with `fpn=True`, `fpn` won't be added.") + if model_name in BOLTS_MODELS: return bolts_backbone_and_num_features(model_name) if model_name in TORCHVISION_MODELS: - return torchvision_backbone_and_num_features(model_name, *args, **kwargs) + return torchvision_backbone_and_num_features(model_name, pretrained) raise ValueError(f"{model_name} is not supported yet.") diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 4c173d93b69..114175b90bb 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -57,7 +57,7 @@ def __init__( self.save_hyperparameters() - self.backbone, num_features = backbone_and_num_features(backbone, pretrained) + self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained) self.head = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index dead9955152..e8759751dc1 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -11,19 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Mapping, Sequence, Type, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union import torch import torchvision from torch import nn from torch.optim import Optimizer +from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor +from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead +from torchvision.models.detection.rpn import AnchorGenerator from torchvision.ops import box_iou from flash.core import Task +from flash.vision.backbones import backbone_and_num_features from flash.vision.detection.data import ObjectDetectionDataPipeline from flash.vision.detection.finetuning import ObjectDetectionFineTuning -_models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn} +_models = { + "fasterrcnn": torchvision.models.detection.fasterrcnn_resnet50_fpn, + "retinanet": torchvision.models.detection.retinanet_resnet50_fpn, +} def _evaluate_iou(target, pred): @@ -37,14 +44,20 @@ def _evaluate_iou(target, pred): class ObjectDetector(Task): - """Image detection task + """Object detection task Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts Args: num_classes: the number of classes for detection, including background - model: either a string of :attr`_models` or a custom nn.Module. - Defaults to 'fasterrcnn_resnet50_fpn'. + model: a string of :attr`_models`. Defaults to 'fasterrcnn'. + backbone: Pretained backbone CNN architecture. Constructs a model with a + ResNet-50-FPN backbone when no backbone is specified. + fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. + pretrained: if true, returns a model pre-trained on COCO train2017 + pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers: number of trainable resnet layers starting from final block. + Only applicable for `fasterrcnn`. loss: the function(s) to update the model with. Has no effect for torchvision detection models. metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. optimizer: The optimizer to use for training. Can either be the actual class or the class name. @@ -57,23 +70,27 @@ class ObjectDetector(Task): def __init__( self, num_classes: int, - model: Union[str, nn.Module] = "fasterrcnn_resnet50_fpn", + model: str = "fasterrcnn", + backbone: Optional[str] = None, + fpn: bool = True, + pretrained: bool = True, + pretrained_backbone: bool = True, + trainable_backbone_layers: int = 3, + anchor_generator: Optional[Type[AnchorGenerator]] = None, loss=None, metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, optimizer: Type[Optimizer] = torch.optim.Adam, - pretrained: bool = True, - learning_rate=1e-3, - **kwargs, + learning_rate: float = 1e-3, + **kwargs: Any, ): self.save_hyperparameters() if model in _models: - model = _models[model](pretrained=pretrained) - if isinstance(model, torchvision.models.detection.FasterRCNN): - in_features = model.roi_heads.box_predictor.cls_score.in_features - head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes) - model.roi_heads.box_predictor = head + model = ObjectDetector.get_model( + model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, + anchor_generator, **kwargs + ) else: ValueError(f"{model} is not supported yet.") @@ -85,6 +102,50 @@ def __init__( optimizer=optimizer, ) + @staticmethod + def get_model( + model_name, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, + anchor_generator, **kwargs + ): + if backbone is None: + # Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. + if model_name == "fasterrcnn": + model = _models[model_name]( + pretrained=pretrained, + pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, + ) + in_features = model.roi_heads.box_predictor.cls_score.in_features + head = FastRCNNPredictor(in_features, num_classes) + model.roi_heads.box_predictor = head + else: + model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone) + model.head = RetinaNetHead( + in_channels=model.backbone.out_channels, + num_anchors=model.head.classification_head.num_anchors, + num_classes=num_classes, + **kwargs + ) + else: + backbone_model, num_features = backbone_and_num_features( + backbone, + fpn, + pretrained_backbone, + trainable_backbone_layers, + **kwargs, + ) + backbone_model.out_channels = num_features + if anchor_generator is None: + anchor_generator = AnchorGenerator( + sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), ) + ) if not hasattr(backbone_model, "fpn") else None + + if model_name == "fasterrcnn": + model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) + else: + model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) + return model + def training_step(self, batch, batch_idx) -> Any: """The training step. Overrides ``Task.training_step`` """ diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py index e388cffd964..0e0884d5c80 100644 --- a/flash/vision/embedding/image_embedder_model.py +++ b/flash/vision/embedding/image_embedder_model.py @@ -112,7 +112,7 @@ def __init__( assert pooling_fn in [torch.mean, torch.max] self.pooling_fn = pooling_fn - self.backbone, num_features = backbone_and_num_features(backbone, pretrained) + self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained) if embedding_dim is None: self.head = nn.Identity() diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/vision/detection/test_data_model_integration.py index ac814c76168..e014086c940 100644 --- a/tests/vision/detection/test_data_model_integration.py +++ b/tests/vision/detection/test_data_model_integration.py @@ -26,12 +26,14 @@ @pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") -def test_detection(tmpdir): +@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", None), ("retinanet", "resnet34"), + ("fasterrcnn", "mobilenet_v2"), ("retinanet", "simclr-imagenet")]) +def test_detection(tmpdir, model, backbone): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) - model = ObjectDetector(num_classes=data.num_classes) + model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes) trainer = flash.Trainer(fast_dev_run=True) diff --git a/tests/vision/detection/test_model.py b/tests/vision/detection/test_model.py index 93bc16375c9..70453e6e73f 100644 --- a/tests/vision/detection/test_model.py +++ b/tests/vision/detection/test_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from pytorch_lightning import Trainer from torch.utils.data import DataLoader, Dataset @@ -62,8 +63,9 @@ def test_init(): assert {"boxes", "labels", "scores"} <= out[0].keys() -def test_training(tmpdir): - model = ObjectDetector(num_classes=2, model="fasterrcnn_resnet50_fpn") +@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"]) +def test_training(tmpdir, model): + model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False) ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) dl = DataLoader(ds, collate_fn=collate_fn) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) diff --git a/tests/vision/test_backbones.py b/tests/vision/test_backbones.py new file mode 100644 index 00000000000..72e5896fe1e --- /dev/null +++ b/tests/vision/test_backbones.py @@ -0,0 +1,13 @@ +import pytest + +from flash.vision.backbones import backbone_and_num_features + + +@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280), + ("simclr-imagenet", 2048)]) +def test_fetch_fasterrcnn_backbone_and_num_features(backbone, expected_num_features): + + backbone_model, num_features = backbone_and_num_features(model_name=backbone, pretrained=False, fpn=False) + + assert backbone_model + assert num_features == expected_num_features