diff --git a/flash/vision/classification/backbones.py b/flash/vision/backbones.py similarity index 99% rename from flash/vision/classification/backbones.py rename to flash/vision/backbones.py index 7f2ac0e904f..89e68520666 100644 --- a/flash/vision/classification/backbones.py +++ b/flash/vision/backbones.py @@ -27,7 +27,6 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr >>> torchvision_backbone_and_num_features('densenet121') # doctest: +ELLIPSIS (Sequential(...), 1024) """ - model = getattr(torchvision.models, model_name, None) if model is None: raise MisconfigurationException(f"{model_name} is not supported by torchvision") diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 9e7858bf1ae..4a77264b82d 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from flash.core.classification import ClassificationTask -from flash.vision.classification.backbones import torchvision_backbone_and_num_features +from flash.vision.backbones import torchvision_backbone_and_num_features from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py index 59255536788..7a504a2fc58 100644 --- a/flash/vision/embedding/image_embedder_model.py +++ b/flash/vision/embedding/image_embedder_model.py @@ -24,20 +24,10 @@ from flash.core import Task from flash.core.data import TaskDataPipeline from flash.core.data.utils import _contains_any_tensor +from flash.vision.backbones import torchvision_backbone_and_num_features from flash.vision.classification.data import _default_valid_transforms, _pil_loader from flash.vision.embedding.model_map import _load_bolts_model, _models -_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731 -_resnet_feats = lambda model: model.fc.in_features # noqa: E731 - -_backbones = { - "resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats), - "resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats), - "resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats), - "resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats), - "resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats), -} - class ImageEmbedderDataPipeline(TaskDataPipeline): """ @@ -129,15 +119,8 @@ def __init__( config = _load_bolts_model(backbone) self.backbone = config['model'] num_features = config['num_features'] - - elif backbone not in _backbones: - raise NotImplementedError(f"Backbone {backbone} is not yet supported") - else: - backbone_fn, split, num_feats = _backbones[backbone] - backbone = backbone_fn(pretrained=pretrained) - self.backbone = split(backbone) - num_features = num_feats(backbone) + self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained) if embedding_dim is None: self.head = nn.Identity()