diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py
index 89e68520666..c492f0015db 100644
--- a/flash/vision/backbones.py
+++ b/flash/vision/backbones.py
@@ -13,9 +13,84 @@
 # limitations under the License.
 from typing import Tuple
 
-import torch.nn as nn
 import torchvision
+from pytorch_lightning.utilities import _BOLTS_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch import nn as nn
+
+if _BOLTS_AVAILABLE:
+    from pl_bolts.models.self_supervised import SimCLR, SwAV
+
+ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"
+
+MOBILENET_MODELS = ["mobilenet_v2"]
+VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"]
+RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"]
+DENSENET_MODELS = ["densenet121", "densenet169", "densenet161", "densenet201"]
+TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS
+
+BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]
+
+
+def backbone_and_num_features(model_name: str, *args, **kwargs) -> Tuple[nn.Module, int]:
+    """Build the backbone registered under ``model_name`` and report its feature width.
+
+    Args:
+        model_name: one of ``BOLTS_MODELS`` or ``TORCHVISION_MODELS``.
+        *args, **kwargs: forwarded to ``torchvision_backbone_and_num_features`` only;
+            they are not passed to the Bolts loader, so e.g. ``pretrained`` has no effect there.
+
+    Returns:
+        The backbone module and the number of features declared for it.
+
+    Raises:
+        ValueError: if ``model_name`` is in neither list.
+    """
+    if model_name in BOLTS_MODELS:
+        return bolts_backbone_and_num_features(model_name)
+
+    if model_name in TORCHVISION_MODELS:
+        return torchvision_backbone_and_num_features(model_name, *args, **kwargs)
+
+    raise ValueError(f"{model_name} is not supported yet.")
+
+
+def bolts_backbone_and_num_features(model_name: str) -> Tuple[nn.Module, int]:
+    """Load a Lightning Bolts backbone from the checkpoints hosted under ``ROOT_S3_BUCKET``.
+
+    Raises ``MisconfigurationException`` when Bolts is not installed and ``ValueError``
+    when ``model_name`` is not one of the Bolts models handled here.
+
+    >>> bolts_backbone_and_num_features('simclr-imagenet')  # doctest: +ELLIPSIS
+    (Sequential(...), 2048)
+    >>> bolts_backbone_and_num_features('swav-imagenet')  # doctest: +ELLIPSIS
+    (Sequential(...), 3000)
+    """
+
+    # TODO: maybe we should use plain pytorch weights so we don't need to rely on bolts to load these
+    # also maybe just use torchhub for the ssl lib
+    def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
+        simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
+        # remove the last two layers & turn it into a Sequential model
+        backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
+        return backbone, 2048
+
+    def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
+        swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
+        # remove the last two layers & turn it into a Sequential model
+        backbone = nn.Sequential(*list(swav.model.children())[:-2])
+        return backbone, 3000
+
+    models = {
+        'simclr-imagenet': load_simclr_imagenet,
+        'swav-imagenet': load_swav_imagenet,
+    }
+    if not _BOLTS_AVAILABLE:
+        raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")
+    if model_name in models:
+        return models[model_name]()
+
+    raise ValueError(f"{model_name} is not supported yet.")
 
 
 def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
@@ -31,22 +106,20 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
     if model is None:
         raise MisconfigurationException(f"{model_name} is not supported by torchvision")
 
-    if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
+    if model_name in MOBILENET_MODELS + VGG_MODELS:
         model = model(pretrained=pretrained)
         backbone = model.features
         num_features = model.classifier[-1].in_features
         return backbone, num_features
 
-    elif model_name in [
-        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
-    ]:
+    elif model_name in RESNET_MODELS:
         model = model(pretrained=pretrained)
         # remove the last two layers & turn it into a Sequential model
         backbone = nn.Sequential(*list(model.children())[:-2])
         num_features = model.fc.in_features
         return backbone, num_features
 
-    elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]:
+    elif model_name in DENSENET_MODELS:
         model = model(pretrained=pretrained)
         backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
         num_features = model.classifier.in_features
diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 4a77264b82d..69a3fd8c859 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -19,7 +19,7 @@
 from torch.nn import functional as F
 
 from flash.core.classification import ClassificationTask
-from flash.vision.backbones import torchvision_backbone_and_num_features
+from flash.vision.backbones import backbone_and_num_features
 from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline
 
 
@@ -57,7 +57,7 @@ def __init__(
 
         self.save_hyperparameters()
 
-        self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
+        self.backbone, num_features = backbone_and_num_features(backbone, pretrained)
 
         self.head = nn.Sequential(
             nn.AdaptiveAvgPool2d((1, 1)),
diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 7a504a2fc58..dd0fc7d6e3f 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -24,9 +24,8 @@
 from flash.core import Task
 from flash.core.data import TaskDataPipeline
 from flash.core.data.utils import _contains_any_tensor
-from flash.vision.backbones import torchvision_backbone_and_num_features
+from flash.vision.backbones import backbone_and_num_features
 from flash.vision.classification.data import _default_valid_transforms, _pil_loader
-from flash.vision.embedding.model_map import _load_bolts_model, _models
 
 
 class ImageEmbedderDataPipeline(TaskDataPipeline):
@@ -115,12 +114,7 @@ def __init__(
         assert pooling_fn in [torch.mean, torch.max]
         self.pooling_fn = pooling_fn
 
-        if backbone in _models:
-            config = _load_bolts_model(backbone)
-            self.backbone = config['model']
-            num_features = config['num_features']
-        else:
-            self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
+        self.backbone, num_features = backbone_and_num_features(backbone, pretrained)
 
         if embedding_dim is None:
             self.head = nn.Identity()
diff --git a/flash/vision/embedding/model_map.py b/flash/vision/embedding/model_map.py
deleted file mode 100644
index 4565440ea99..00000000000
--- a/flash/vision/embedding/model_map.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from contextlib import suppress
-
-from pytorch_lightning.utilities import _BOLTS_AVAILABLE
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-if _BOLTS_AVAILABLE:
-    with suppress(TypeError):
-        from pl_bolts.models.self_supervised import SimCLR, SwAV
-
-ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"
-
-
-def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
-    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
-    model_config = {'model': simclr.encoder, 'emb_size': 2048}
-    return model_config
-
-
-def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
-    swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
-    model_config = {'model': swav.model, 'num_features': 3000}
-    return model_config
-
-
-_models = {
-    'simclr-imagenet': load_simclr_imagenet,
-    'swav-imagenet': load_swav_imagenet,
-}
-
-
-def _load_bolts_model(name):
-    if not _BOLTS_AVAILABLE:
-        raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")
-    if name in _models:
-        return _models[name]()
-    raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py
index b8d04fd47cb..c419a22a966 100644
--- a/tests/vision/classification/test_model.py
+++ b/tests/vision/classification/test_model.py
@@ -51,7 +51,7 @@ def test_init_train(tmpdir, backbone):
 
 
 def test_non_existent_backbone():
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(ValueError):
         ImageClassifier(2, "i am never going to implement this lol")
 
 
diff --git a/tests/vision/test_download.py b/tests/vision/test_download.py
index 2ae3f06d971..f1448a419f7 100644
--- a/tests/vision/test_download.py
+++ b/tests/vision/test_download.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 import pytest
 
-from flash.vision.embedding.model_map import _load_bolts_model
+from flash.vision.backbones import bolts_backbone_and_num_features
 
 
 @pytest.mark.parametrize("name", ['simclr-imagenet', 'swav-imagenet'])
 def test_load_bolts(name):
-    _load_bolts_model(name)
+    bolts_backbone_and_num_features(name)