diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72a4f2570a..45b362ca70 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
+
 ### Deprecated
 
 ### Removed
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index a0b5026c82..861709ff57 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -194,7 +194,9 @@ def decorator(func):
 
         if not available:
             modules = [f"'{module}'" for module in modules]
-            modules.append(f"'lightning-flash[{','.join(extras)}]'")
+
+            if extras:
+                modules.append(f"'lightning-flash[{','.join(extras)}]'")
 
             @functools.wraps(func)
             def wrapper(*args, **kwargs):
diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py
index 7fc1f91b88..0e61299905 100644
--- a/flash/image/embedding/model.py
+++ b/flash/image/embedding/model.py
@@ -72,13 +72,13 @@ class ImageEmbedder(AdapterTask):
     backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
     transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS
 
-    required_extras: List[str] = ["image", "vissl", "fairscale"]
+    required_extras: str = "image"
 
     def __init__(
         self,
-        training_strategy: str,
-        head: str,
-        pretraining_transform: str,
+        training_strategy: str = "default",
+        head: Optional[str] = None,
+        pretraining_transform: Optional[str] = None,
         backbone: str = "resnet18",
         pretrained: bool = False,
         optimizer: OPTIMIZER_TYPE = "Adam",
@@ -113,7 +113,7 @@ def __init__(
         loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)
 
         adapter = metadata["metadata"]["adapter"].from_task(
-            self,
+            task=self,
             loss_fn=loss_fn,
             backbone=model,
             head=head,
@@ -128,12 +128,16 @@ def __init__(
             learning_rate=learning_rate,
         )
 
-        self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
+        if pretraining_transform is not None:
+            warnings.warn(
+                "Overriding any transforms from the `DataModule` with the pretraining transform: "
+                f"{pretraining_transform}."
+            )
+            self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
 
-        warnings.warn(
-            "Warning: VISSL ImageEmbedder overrides any user provided transforms"
-            " with pre-defined transforms for the training strategy."
-        )
+        if "providers" in metadata["metadata"] and metadata["metadata"]["providers"].name == "Facebook Research/vissl":
+            if pretraining_transform is None:
+                raise ValueError("A `pretraining_transform` must be provided when using a VISSL training strategy.")
 
     def forward(self, x: torch.Tensor) -> Any:
         return self.model(x)
diff --git a/flash/image/embedding/strategies/__init__.py b/flash/image/embedding/strategies/__init__.py
index 8d010d7bb8..de9823b0a6 100644
--- a/flash/image/embedding/strategies/__init__.py
+++ b/flash/image/embedding/strategies/__init__.py
@@ -1,5 +1,7 @@
 from flash.core.registry import FlashRegistry  # noqa: F401
+from flash.image.embedding.strategies.default import register_default_strategy
 from flash.image.embedding.strategies.vissl_strategies import register_vissl_strategies  # noqa: F401
 
 IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies")
 register_vissl_strategies(IMAGE_EMBEDDER_STRATEGIES)
+register_default_strategy(IMAGE_EMBEDDER_STRATEGIES)
diff --git a/flash/image/embedding/strategies/default.py b/flash/image/embedding/strategies/default.py
new file mode 100644
index 0000000000..2a8fa9db0f
--- /dev/null
+++ b/flash/image/embedding/strategies/default.py
@@ -0,0 +1,89 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from typing import Any, Optional
+
+import torch
+
+from flash.core.adapter import Adapter, AdapterTask
+from flash.core.data.io.input import DataKeys
+from flash.core.model import Task
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.url_error import catch_url_error
+
+
+class DefaultAdapter(Adapter):
+    """The ``DefaultAdapter`` is an :class:`~flash.core.adapter.Adapter` that only supports prediction."""
+
+    required_extras: str = "image"
+
+    def __init__(self, backbone: torch.nn.Module):
+        super().__init__()
+
+        self.backbone = backbone
+
+    @classmethod
+    @catch_url_error
+    def from_task(
+        cls,
+        task: AdapterTask,
+        backbone: torch.nn.Module,
+        **kwargs,
+    ) -> Adapter:
+        adapter = cls(backbone)
+        adapter.__dict__["_task"] = task
+        return adapter
+
+    def training_step(self, batch: Any, batch_idx: int) -> Any:
+        raise NotImplementedError(
+            'Training an `ImageEmbedder` with `training_strategy="default"` is not supported. '
+            "Use a different strategy instead."
+        )
+
+    def validation_step(self, batch: Any, batch_idx: int) -> Any:
+        raise NotImplementedError(
+            'Validating an `ImageEmbedder` with `training_strategy="default"` is not supported. '
+            "Use a different strategy instead."
+        )
+
+    def test_step(self, batch: Any, batch_idx: int) -> Any:
+        raise NotImplementedError(
+            'Testing an `ImageEmbedder` with `training_strategy="default"` is not supported. '
+            "Use a different strategy instead."
+        )
+
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+        batch[DataKeys.PREDS] = Task.predict_step(
+            self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
+        )
+        return batch
+
+
+def default(head: Optional[str] = None, loss_fn: Optional[str] = None, **kwargs):
+    """Return ``(None, None, [])`` for the loss function, head, and hooks.
+
+    The default strategy only supports prediction.
+    """
+    if head is not None:
+        warnings.warn(f"The default strategy does not use a head, so the given head ({head}) will be ignored.")
+
+    if loss_fn is not None:
+        warnings.warn(f"The default strategy does not use a loss function, so the given loss_fn ({loss_fn}) will be ignored.")
+
+    return None, None, []
+
+
+def register_default_strategy(register: FlashRegistry):
+    """Register the default strategy with the given ``FlashRegistry``."""
+    register(default, name="default", adapter=DefaultAdapter)
diff --git a/flash/image/embedding/strategies/vissl_strategies.py b/flash/image/embedding/strategies/vissl_strategies.py
index 6089ee9dc1..a4f4a1a224 100644
--- a/flash/image/embedding/strategies/vissl_strategies.py
+++ b/flash/image/embedding/strategies/vissl_strategies.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _VISSL_AVAILABLE
+from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
 from flash.core.utilities.providers import _VISSL
 from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS
 from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS
@@ -23,6 +23,7 @@
     from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook
 
 
+@requires(["vissl", "classy_vision"])
 def swav(head: str = "swav_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)
@@ -30,6 +31,7 @@
     return loss_fn, head, [SwAVUpdateQueueScoresHook(), NormalizePrototypesHook(), TrainingSetupHook()]
 
 
+@requires(["vissl", "classy_vision"])
 def simclr(head: str = "simclr_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("simclr_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)
@@ -37,6 +39,7 @@
     return loss_fn, head, [SimCLRTrainingSetupHook()]
 
 
+@requires(["vissl", "classy_vision"])
 def barlow_twins(head: str = "barlow_twins_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("barlow_twins_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)
@@ -45,6 +48,5 @@
 
 
 def register_vissl_strategies(register: FlashRegistry):
-    if _VISSL_AVAILABLE:
-        for training_strategy in (swav, simclr, barlow_twins):
-            register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
+    for training_strategy in (swav, simclr, barlow_twins):
+        register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
diff --git a/flash_examples/integrations/fiftyone/image_embedding.py b/flash_examples/integrations/fiftyone/image_embedding.py
index 019bd9cffe..9b7382034d 100644
--- a/flash_examples/integrations/fiftyone/image_embedding.py
+++ b/flash_examples/integrations/fiftyone/image_embedding.py
@@ -14,9 +14,12 @@
 import fiftyone as fo
 import fiftyone.brain as fob
 import numpy as np
+import torch
 
+import flash
 from flash.core.data.utils import download_data
 from flash.image import ImageEmbedder
+from flash.image.classification.data import ImageClassificationData
 
 # 1 Download data
 download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
@@ -26,13 +29,18 @@
     "data/hymenoptera_data/test/",
     fo.types.ImageClassificationDirectoryTree,
 )
+datamodule = ImageClassificationData.from_files(
+    predict_files=dataset.values("filepath"),
+    batch_size=16,
+)
 
 # 3 Load model
-embedder = ImageEmbedder(backbone="resnet101")
+embedder = ImageEmbedder(backbone="resnet18")
 
 # 4 Generate embeddings
-filepaths = dataset.values("filepath")
-embeddings = np.stack(embedder.predict(filepaths))
+trainer = flash.Trainer(gpus=torch.cuda.device_count())
+embedding_batches = trainer.predict(embedder, datamodule=datamodule)
+embeddings = np.stack(sum(embedding_batches, []))
 
 # 5 Visualize in FiftyOne App
 results = fob.compute_visualization(dataset, embeddings=embeddings)
diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py
index c865446b93..fb3ea0d9a5 100644
--- a/tests/image/embedding/test_model.py
+++ b/tests/image/embedding/test_model.py
@@ -87,3 +87,59 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
     for prediction_batch in predictions:
         for prediction in prediction_batch:
             assert prediction.size(0) == embedding_size
+
+
+@pytest.mark.skipif(not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
+@pytest.mark.parametrize(
+    "backbone, training_strategy, head, pretraining_transform, expected_exception",
+    [
+        ("resnet18", "simclr", "simclr_head", None, ValueError),
+        ("resnet18", "simclr", None, "simclr_transform", KeyError),
+    ],
+)
+def test_vissl_training_with_wrong_arguments(
+    backbone, training_strategy, head, pretraining_transform, expected_exception
+):
+    with pytest.raises(expected_exception):
+        ImageEmbedder(
+            backbone=backbone,
+            training_strategy=training_strategy,
+            head=head,
+            pretraining_transform=pretraining_transform,
+        )
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="torchvision not installed.")
+@pytest.mark.parametrize(
+    "backbone, embedding_size",
+    [
+        ("resnet18", 512),
+        ("vit_small_patch16_224", 384),
+    ],
+)
+def test_only_embedding(backbone, embedding_size):
+    datamodule = ImageClassificationData.from_datasets(
+        predict_dataset=FakeData(8),
+        batch_size=4,
+        transform_kwargs=dict(image_size=(224, 224)),
+    )
+
+    embedder = ImageEmbedder(backbone=backbone)
+    trainer = flash.Trainer()
+
+    predictions = trainer.predict(embedder, datamodule=datamodule)
+    for prediction_batch in predictions:
+        for prediction in prediction_batch:
+            assert prediction.size(0) == embedding_size
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="torchvision not installed.")
+def test_not_implemented_steps():
+    embedder = ImageEmbedder(backbone="resnet18")
+
+    with pytest.raises(NotImplementedError):
+        embedder.training_step([], 0)
+    with pytest.raises(NotImplementedError):
+        embedder.validation_step([], 0)
+    with pytest.raises(NotImplementedError):
+        embedder.test_step([], 0)
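
Note: below is a minimal sketch of the prediction-only flow this patch enables, assuming only the `lightning-flash[image]` extra is installed; the image file paths are hypothetical placeholders. With the new `"default"` training strategy, `ImageEmbedder` produces embeddings without VISSL:

```python
import flash
from flash.image import ImageEmbedder
from flash.image.classification.data import ImageClassificationData

# Hypothetical image paths; replace with real files.
datamodule = ImageClassificationData.from_files(
    predict_files=["image_1.png", "image_2.png"],
    batch_size=2,
)

# No training_strategy, head, or pretraining_transform is needed:
# the new "default" strategy no longer pulls in VISSL.
embedder = ImageEmbedder(backbone="resnet18")

# Under the default strategy, training/validation/test raise
# NotImplementedError; only prediction is supported.
trainer = flash.Trainer()
embedding_batches = trainer.predict(embedder, datamodule=datamodule)
print(embedding_batches)  # one list of embedding tensors per predict batch
```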