diff --git a/.azure-pipelines/gpu-example-tests.yml b/.azure-pipelines/gpu-example-tests.yml
index db47dcbc7a..e71374ac4f 100644
--- a/.azure-pipelines/gpu-example-tests.yml
+++ b/.azure-pipelines/gpu-example-tests.yml
@@ -12,7 +12,8 @@ jobs:
     parameters:
       configs:
         - "image"
-        - "image,image_extras"
+        - "icevision"
+        - "vissl"
         - "text"
         - "tabular"
         - "video"
diff --git a/.azure-pipelines/testing-template.yml b/.azure-pipelines/testing-template.yml
index 42d132fb02..6c69b810f5 100644
--- a/.azure-pipelines/testing-template.yml
+++ b/.azure-pipelines/testing-template.yml
@@ -36,7 +36,7 @@ jobs:
   - bash: |
       # python -m pip install "pip==20.1"
-      if [ "${{config}}" == "image,image_extras" ]; then pip install '.[image]' icevision effdet icedata; else pip install '.[${{config}}]'; fi
+      if [ "${{config}}" == "icevision" ]; then pip install '.[image]' icevision effdet icedata; elif [ "${{config}}" == "vissl" ]; then pip install '.[image]'; else pip install '.[${{config}}]'; fi
       pip install '.[test]' --upgrade-strategy only-if-needed
       pip list
     displayName: 'Install dependencies'

@@ -46,11 +46,18 @@ jobs:
       pip uninstall -y opencv-python-headless
       pip install opencv-python-headless==4.5.5.64
     displayName: 'Install OpenCV dependencies'
-    condition: eq('${{ config }}', 'image,image_extras')
+    condition: eq('${{ config }}', 'icevision')
+
+  - bash: |
+      pip install fairscale
+      pip install git+https://github.com/facebookresearch/ClassyVision.git
+      pip install git+https://github.com/facebookresearch/vissl.git
+    displayName: 'Install VISSL dependencies'
+    condition: eq('${{ config }}', 'vissl')

   - bash: |
       python -c "import torch; print(f'found GPUs: {torch.cuda.device_count()}')"
-      python -m coverage run --source flash -m pytest tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
+      python -m coverage run --source flash -m pytest tests/examples/test_scripts.py tests/image/embedding/test_model.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
     env:
       CUDA_VISIBLE_DEVICES: ${{gids}}
       FLASH_TEST_TOPIC: ${{ config }}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb236747b4..38cc4e7267 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed GPU support for self-supervised training with the `ImageEmbedder`. ([#1256](https://github.com/PyTorchLightning/lightning-flash/pull/1256))
+
 - Fixed a bug where collate functions were never called in the `ImageEmbedder` class. ([#1217](https://github.com/PyTorchLightning/lightning-flash/pull/1217))

 - Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index fe9019cd8b..78f2232c73 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -4,6 +4,10 @@
     :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_embedder.svg
     :tags: Image,Embedding

+.. warning::
+
+    Multi-GPU training is not currently supported by the :class:`~flash.image.embedding.model.ImageEmbedder` task.
+
 .. _image_embedder:

 ##############
@@ -17,7 +21,9 @@ The Task

 Image embedding encodes an image into a vector of features which can be used for a downstream task.
 This could include: clustering, similarity search, or classification.
+The Flash :class:`~flash.image.embedding.model.ImageEmbedder` can be trained with Self Supervised Learning (SSL) to improve the quality of the embeddings it produces for your data. The :class:`~flash.image.embedding.model.ImageEmbedder` internally relies on `VISSL <https://github.com/facebookresearch/vissl>`_.
+You can read more about our integration with VISSL here: :ref:`vissl`.

 ------

@@ -26,18 +32,29 @@ Example
 *******

 Let's see how to configure a training strategy for the :class:`~flash.image.embedding.model.ImageEmbedder` task.
-A vanilla :class:`~flash.core.data.data_module.DataModule` object be created using standard Datasets as shown below.
-Then the user can configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
-There are options provided to send additional arguments to config selections.
-This task can now be sent to the ``fit()`` method of :class:`~flash.core.trainer.Trainer`.
-
-.. note::
-
-   A lot of VISSL loss functions use hard-coded ``torch.distributed`` methods. The user is suggested to use ``accelerator=ddp`` even with a single GPU.
-   Only ``barlow_twins`` training strategy works on the CPU. All other loss functions are configured to work on GPUs.
+First we create an :class:`~flash.image.classification.data.ImageClassificationData` object using a `Dataset` from torchvision.
+Next, we configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
+Finally, we construct a :class:`~flash.core.trainer.Trainer` and call ``fit()``.
+Here's the full example:

 .. literalinclude:: ../../../flash_examples/image_embedder.py
     :language: python
     :lines: 14-

 To learn how to view the available backbones / heads for this task, see :ref:`backbones_heads`.
+You can view the available training strategies with the :meth:`~flash.image.embedding.model.ImageEmbedder.available_training_strategies` method.
+
+.. note::
+
+    The ``"dino"`` training strategy only supports single-GPU training with ``strategy="ddp"``.
+
+The ``head`` and ``pretraining_transform`` arguments should match the choice of ``training_strategy``, following this table:
+
+===================== ===================== ==========================
+``training_strategy`` ``head``              ``pretraining_transform``
+===================== ===================== ==========================
+``simclr``            ``simclr_head``       ``simclr_transform``
+``barlow_twins``      ``barlow_twins_head`` ``barlow_twins_transform``
+``swav``              ``swav_head``         ``swav_transform``
+``dino``              ``dino_head``         ``dino_transform``
+===================== ===================== ==========================
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index c5417f73b9..a0b5026c82 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -291,7 +291,7 @@ def _import_module(self):
 if "FLASH_TEST_TOPIC" in os.environ:
     topic = os.environ["FLASH_TEST_TOPIC"]
     _IMAGE_TESTING = topic == "image"
-    _IMAGE_EXTRAS_TESTING = topic == "image,image_extras"
+    _IMAGE_EXTRAS_TESTING = topic == "image,image_extras" or topic == "icevision" or topic == "vissl"
     _VIDEO_TESTING = topic == "video"
     _VIDEO_EXTRAS_TESTING = topic == "video,video_extras"
     _TABULAR_TESTING = topic == "tabular"
diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py
index 6ca62f5528..edd7e37d93 100644
--- a/flash/image/embedding/model.py
+++ b/flash/image/embedding/model.py
@@ -18,21 +18,24 @@
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
 from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
+from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
+from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
+from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

 if _VISSL_AVAILABLE:
     import classy_vision
     import classy_vision.generic.distributed_util

-    from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
-    from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
-    from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS
-
     # patch this to avoid classy vision/vissl based distributed training
     classy_vision.generic.distributed_util.get_world_size = lambda: 1
-else:
-    IMAGE_EMBEDDER_BACKBONES = FlashRegistry("backbones")
-    IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies")
-    IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms")
+
+# Skip doctests if requirements aren't available
+__doctest_skip__ = []
+if not _VISSL_AVAILABLE:
+    __doctest_skip__ += [
+        "ImageEmbedder",
+        "ImageEmbedder.*",
+    ]


 class ImageEmbedder(AdapterTask):
@@ -130,6 +133,18 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloade
     @classmethod
     @requires(["image", "vissl", "fairscale"])
     def available_training_strategies(cls) -> List[str]:
+        """Get the list of available training strategies (passed to the ``training_strategy`` argument) for this
+        task.
+
+        Examples
+        ________
+
+        .. doctest::
+
+            >>> from flash.image import ImageEmbedder
+            >>> ImageEmbedder.available_training_strategies()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            ['barlow_twins', ..., 'swav']
+        """
         registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None)
         if registry is None:
             return []
diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py
index 119db01974..41996b5c66 100644
--- a/flash/image/embedding/vissl/adapter.py
+++ b/flash/image/embedding/vissl/adapter.py
@@ -51,6 +51,9 @@ def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None:
         # set for momentum teacher based hooks
         self.last_batch = AttrDict({"sample": AttrDict({"input": None, "data_momentum": None})})

+        # used in dino
+        self.additional_log_data = {}
+

 class VISSLAdapter(Adapter, AdaptVISSLHooks):
     """The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL.
diff --git a/flash/image/embedding/vissl/hooks.py b/flash/image/embedding/vissl/hooks.py
index 6bd090a5ef..9176883f59 100644
--- a/flash/image/embedding/vissl/hooks.py
+++ b/flash/image/embedding/vissl/hooks.py
@@ -81,6 +81,9 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->
             task.loss.info_criterion.precompute_pos_neg_mask()

+        # Cast the loss to the correct device / dtype
+        task.loss.to(lightning_module.device, lightning_module.dtype)
+

 class AdaptVISSLHooks(ModelHooks):
     def __init__(self, hooks: List[ClassyHook], task) -> None:
diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py
index 7edc360427..5a51d48eba 100644
--- a/flash_examples/image_embedder.py
+++ b/flash_examples/image_embedder.py
@@ -21,23 +21,21 @@
 # 1. Download the data and prepare the datamodule
 datamodule = ImageClassificationData.from_datasets(
     train_dataset=CIFAR10(".", download=True),
-    batch_size=16,
+    batch_size=4,
 )

 # 2. Build the task
 embedder = ImageEmbedder(
-    backbone="resnet",
+    backbone="vision_transformer",
     training_strategy="barlow_twins",
     head="barlow_twins_head",
     pretraining_transform="barlow_twins_transform",
     training_strategy_kwargs={"latent_embedding_dim": 128},
-    pretraining_transform_kwargs={"size_crops": [196]},
+    pretraining_transform_kwargs={"size_crops": [32]},
 )

 # 3. Create the trainer and pre-train the encoder
-# use accelerator='ddp' when using GPU(s),
-# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
-trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
+trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
 trainer.fit(embedder, datamodule=datamodule)

 # 4. Save the model!
@@ -50,7 +48,8 @@
         predict_files=[
             "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
             "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
-        ]
+        ],
+        batch_size=3,
     )

 embeddings = trainer.predict(embedder, datamodule=datamodule)
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index f57746c65c..1162caa91d 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -29,9 +29,10 @@
     _TABULAR_TESTING,
     _TEXT_TESTING,
     _VIDEO_TESTING,
+    _VISSL_AVAILABLE,
 )
 from tests.examples.utils import run_test
-from tests.helpers.forked import forked
+from tests.helpers.decorators import forked

 root = Path(__file__).parent.parent.parent

@@ -56,6 +57,15 @@
         "image_classification_multi_label.py",
         marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
     ),
+    pytest.param(
+        "image_embedder.py",
+        marks=[
+            pytest.mark.skipif(
+                not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="image libraries aren't installed"
+            ),
+            pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU"),
+        ],
+    ),
     pytest.param(
         "object_detection.py",
         marks=pytest.mark.skipif(
@@ -74,7 +84,6 @@
             not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
         ),
     ),
-    # pytest.param("finetuning", "object_detection.py"),  # TODO: takes too long.
     pytest.param(
         "question_answering.py",
         marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
diff --git a/tests/helpers/forked.py b/tests/helpers/decorators.py
similarity index 100%
rename from tests/helpers/forked.py
rename to tests/helpers/decorators.py
diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py
index 9cd78e8b74..9f26fc8444 100644
--- a/tests/image/embedding/test_model.py
+++ b/tests/image/embedding/test_model.py
@@ -17,7 +17,12 @@
 import torch

 import flash
-from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
+from flash.core.utilities.imports import (
+    _IMAGE_AVAILABLE,
+    _PL_GREATER_EQUAL_1_5_0,
+    _TORCHVISION_AVAILABLE,
+    _VISSL_AVAILABLE,
+)
 from flash.image import ImageClassificationData, ImageEmbedder

 if _TORCHVISION_AVAILABLE:
@@ -50,6 +55,7 @@ def test_load_from_checkpoint_dependency_error():
         ImageEmbedder.load_from_checkpoint("not_a_real_checkpoint.pt")

+@pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU")
 @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
 @pytest.mark.parametrize(
     "backbone, training_strategy, head, pretraining_transform",
@@ -70,7 +76,7 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
     # moco strategy, transform and head is not added for this test as it doesn't work as of now.
     datamodule = ImageClassificationData.from_datasets(
         train_dataset=FakeData(16),
-        predict_dataset=FakeData(4),
+        predict_dataset=FakeData(8),
         batch_size=4,
     )

@@ -81,7 +87,22 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
         pretraining_transform=pretraining_transform,
     )

-    trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())
+    kwargs = {}
+
+    # DINO only works with DDP
+    if training_strategy == "dino":
+        if _PL_GREATER_EQUAL_1_5_0:
+            kwargs["strategy"] = "ddp"
+        else:
+            kwargs["accelerator"] = "ddp"
+
+    trainer = flash.Trainer(
+        max_steps=3,
+        max_epochs=1,
+        gpus=torch.cuda.device_count(),
+        **kwargs,
+    )
+
     trainer.fit(embedder, datamodule=datamodule)

     predictions = trainer.predict(embedder, datamodule=datamodule)
     for prediction_batch in predictions:
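
Reviewer note, not part of the patch: the pieces above interact. The docs table pairs each ``training_strategy`` with a matching ``head`` and ``pretraining_transform``, the new docs note requires DDP for ``dino`` even on a single GPU, and the ``hooks.py`` change moves the VISSL loss to the right device. Below is a minimal local smoke-test sketch that exercises all of this together. It is written under assumptions: this branch is installed with the ``[image]`` extra plus VISSL and fairscale (as in the new CI step), PyTorch Lightning >= 1.5 is available (so the ``strategy="ddp"`` argument exists), and ``fast_dev_run=True`` stands in for the ``max_steps=3`` used in the real test.

    # Hypothetical smoke test, not part of this PR: pairs each training strategy
    # with its matching head and transform, following the docs table above.
    import torch
    from torchvision.datasets import FakeData

    import flash
    from flash.image import ImageClassificationData, ImageEmbedder

    # Small fake dataset so every strategy runs quickly on one GPU (or CPU).
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(16),
        batch_size=4,
    )

    for strategy in ["simclr", "barlow_twins", "swav", "dino"]:
        embedder = ImageEmbedder(
            backbone="vision_transformer",
            training_strategy=strategy,
            head=f"{strategy}_head",  # e.g. "swav" -> "swav_head", per the docs table
            pretraining_transform=f"{strategy}_transform",
        )

        # Per the new docs note, "dino" needs DDP even on a single GPU.
        kwargs = {"strategy": "ddp"} if strategy == "dino" else {}
        trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count(), **kwargs)
        trainer.fit(embedder, datamodule=datamodule)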