
Fix VISSL on GPU and add VISSL GPU CI #1256

Merged · 25 commits · Mar 30, 2022
3 changes: 2 additions & 1 deletion .azure-pipelines/gpu-example-tests.yml
Expand Up @@ -12,7 +12,8 @@ jobs:
parameters:
configs:
- "image"
- "image,image_extras"
- "icevision"
- "vissl"
- "text"
- "tabular"
- "video"
13 changes: 10 additions & 3 deletions .azure-pipelines/testing-template.yml
Expand Up @@ -36,7 +36,7 @@ jobs:

- bash: |
# python -m pip install "pip==20.1"
if [ "${{config}}" == "image,image_extras" ]; then pip install '.[image]' icevision effdet icedata; else pip install '.[${{config}}]'; fi
if [ "${{config}}" == "icevision" ]; then pip install '.[image]' icevision effdet icedata; elif [ "${{config}}" == "vissl" ]; then pip install '.[image]'; else pip install '.[${{config}}]'; fi
pip install '.[test]' --upgrade-strategy only-if-needed
pip list
displayName: 'Install dependencies'
Expand All @@ -46,11 +46,18 @@ jobs:
pip uninstall -y opencv-python-headless
pip install opencv-python-headless==4.5.5.64
displayName: 'Install OpenCV dependencies'
condition: eq('${{ config }}', 'image,image_extras')
condition: eq('${{ config }}', 'icevision')

- bash: |
pip install fairscale
pip install git+https://github.com/facebookresearch/ClassyVision.git
pip install git+https://github.com/facebookresearch/vissl.git
displayName: 'Install VISSL dependencies'
condition: eq('${{ config }}', 'vissl')

- bash: |
python -c "import torch; print(f'found GPUs: {torch.cuda.device_count()}')"
python -m coverage run --source flash -m pytest tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
python -m coverage run --source flash -m pytest tests/examples/test_scripts.py tests/image/embedding/test_model.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
env:
CUDA_VISIBLE_DEVICES: ${{gids}}
FLASH_TEST_TOPIC: ${{ config }}
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed GPU support for self-supervised training with the `ImageEmbedder` ([#1256](https://github.com/PyTorchLightning/lightning-flash/pull/1256))

- Fixed a bug where collate functions were never called in the `ImageEmbedder` class. ([#1217](https://github.com/PyTorchLightning/lightning-flash/pull/1217))

- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
35 changes: 26 additions & 9 deletions docs/source/reference/image_embedder.rst
Expand Up @@ -4,6 +4,10 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_embedder.svg
:tags: Image,Embedding

.. warning::

Multi-gpu training is not currently supported by the :class:`~flash.image.embedding.model.ImageEmbedder` task.

.. _image_embedder:

##############
Expand All @@ -17,7 +21,9 @@ The Task
Image embedding encodes an image into a vector of features which can be used for a downstream task.
This could include: clustering, similarity search, or classification.

The Flash :class:`~flash.image.embedding.model.ImageEmbedder` can be trained with Self Supervised Learning (SSL) to improve the quality of the embeddings it produces for your data.
The :class:`~flash.image.embedding.model.ImageEmbedder` internally relies on `VISSL <https://vissl.ai/>`_.
You can read more about our integration with VISSL here: :ref:`vissl`.
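
As an illustration of the similarity-search use case, here is a minimal sketch that ranks stored embeddings against a query embedding (the tensors and their dimensionality are made up for the example; in practice they would come from ``trainer.predict``):

.. code-block:: python

    import torch

    # Hypothetical embeddings: one vector per image, stacked into a single tensor.
    embeddings = torch.randn(10, 128)
    query = torch.randn(1, 128)

    # Cosine similarity between the query and every stored embedding.
    scores = torch.nn.functional.cosine_similarity(query, embeddings, dim=1)

    # Indices of the most similar images, best match first.
    ranking = scores.argsort(descending=True)
    print(ranking[:5])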

------

Expand All @@ -26,18 +32,29 @@ Example
*******

Let's see how to configure a training strategy for the :class:`~flash.image.embedding.model.ImageEmbedder` task.
A vanilla :class:`~flash.core.data.data_module.DataModule` object be created using standard Datasets as shown below.
Then the user can configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
There are options provided to send additional arguments to config selections.
This task can now be sent to the ``fit()`` method of :class:`~flash.core.trainer.Trainer`.

.. note::

A lot of VISSL loss functions use hard-coded ``torch.distributed`` methods. The user is suggested to use ``accelerator=ddp`` even with a single GPU.
Only ``barlow_twins`` training strategy works on the CPU. All other loss functions are configured to work on GPUs.
First we create an :class:`~flash.image.classification.data.ImageClassificationData` object using a `Dataset` from torchvision.
Next, we configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
Finally, we construct a :class:`~flash.core.trainer.Trainer` and call ``fit()``.
Here's the full example:

.. literalinclude:: ../../../flash_examples/image_embedder.py
:language: python
:lines: 14-

To learn how to view the available backbones / heads for this task, see :ref:`backbones_heads`.
You can view the available training strategies with the :meth:`~flash.image.embedding.model.ImageEmbedder.available_training_strategies` method.
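
For example (this assumes the ``vissl`` and ``fairscale`` packages are installed; the printed list is abbreviated and depends on the installed version):

.. code-block:: python

    from flash.image import ImageEmbedder

    print(ImageEmbedder.available_training_strategies())
    # ['barlow_twins', ..., 'swav']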

.. note::

The ``"dino"`` training strategy only supports single GPU training with ``strategy="DDP"``.

The ``head`` and ``pretraining_transform`` arguments should match the choice of ``training_strategy`` following this table:

===================== ================ ==========================
``training_strategy`` ``head`` ``pretraining_transform``
===================== ================ ==========================
``simclr`` ``simclr_head`` ``simclr_transform``
``barlow_twins`` ``barlow_twins`` ``barlow_twins_transform``
``swav`` ``swav_head`` ``swav_transform``
``dino`` ``dino_head`` ``dino_transform``
===================== ================ ==========================
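
For instance, a minimal sketch of one matching combination from the ``simclr`` row would look like this (assuming the VISSL extras are installed; the ``backbone`` choice is illustrative):

.. code-block:: python

    from flash.image import ImageEmbedder

    embedder = ImageEmbedder(
        backbone="vision_transformer",
        training_strategy="simclr",
        head="simclr_head",
        pretraining_transform="simclr_transform",
    )
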
2 changes: 1 addition & 1 deletion flash/core/utilities/imports.py
Expand Up @@ -291,7 +291,7 @@ def _import_module(self):
if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
_IMAGE_TESTING = topic == "image"
_IMAGE_EXTRAS_TESTING = topic == "image,image_extras"
_IMAGE_EXTRAS_TESTING = topic == "image,image_extras" or topic == "icevision" or topic == "vissl"
_VIDEO_TESTING = topic == "video"
_VIDEO_EXTRAS_TESTING = topic == "video,video_extras"
_TABULAR_TESTING = topic == "tabular"
31 changes: 23 additions & 8 deletions flash/image/embedding/model.py
Expand Up @@ -18,21 +18,24 @@
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

if _VISSL_AVAILABLE:
import classy_vision
import classy_vision.generic.distributed_util

from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

# patch this to avoid classy vision/vissl based distributed training
classy_vision.generic.distributed_util.get_world_size = lambda: 1
else:
IMAGE_EMBEDDER_BACKBONES = FlashRegistry("backbones")
IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies")
IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms")

# Skip doctests if requirements aren't available
__doctest_skip__ = []
if not _VISSL_AVAILABLE:
__doctest_skip__ += [
"ImageEmbedder",
"ImageEmbedder.*",
]


class ImageEmbedder(AdapterTask):
Expand Down Expand Up @@ -130,6 +133,18 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloade
@classmethod
@requires(["image", "vissl", "fairscale"])
def available_training_strategies(cls) -> List[str]:
"""Get the list of available training strategies (passed to the ``training_strategy`` argument) for this
task.

Examples
________

.. doctest::

>>> from flash.image import ImageEmbedder
>>> ImageEmbedder.available_training_strategies() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['barlow_twins', ..., 'swav']
"""
registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None)
if registry is None:
return []
3 changes: 3 additions & 0 deletions flash/image/embedding/vissl/adapter.py
Expand Up @@ -51,6 +51,9 @@ def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None:
# set for momentum teacher based hooks
self.last_batch = AttrDict({"sample": AttrDict({"input": None, "data_momentum": None})})

# used in dino
self.additional_log_data = {}


class VISSLAdapter(Adapter, AdaptVISSLHooks):
"""The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL.
3 changes: 3 additions & 0 deletions flash/image/embedding/vissl/hooks.py
Expand Up @@ -81,6 +81,9 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->

task.loss.info_criterion.precompute_pos_neg_mask()

# Cast the loss to the correct device / dtype
task.loss.to(task.vissl_adapter.adapter_task.device, task.vissl_adapter.adapter_task.dtype)


class AdaptVISSLHooks(ModelHooks):
def __init__(self, hooks: List[ClassyHook], task) -> None:
13 changes: 6 additions & 7 deletions flash_examples/image_embedder.py
Expand Up @@ -21,23 +21,21 @@
# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
train_dataset=CIFAR10(".", download=True),
batch_size=16,
batch_size=4,
)

# 2. Build the task
embedder = ImageEmbedder(
backbone="resnet",
backbone="vision_transformer",
training_strategy="barlow_twins",
head="barlow_twins_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 128},
pretraining_transform_kwargs={"size_crops": [196]},
pretraining_transform_kwargs={"size_crops": [32]},
)

# 3. Create the trainer and pre-train the encoder
# use accelerator='ddp' when using GPU(s),
# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
Expand All @@ -50,7 +48,8 @@
predict_files=[
"data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
"data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
]
],
batch_size=3,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

13 changes: 11 additions & 2 deletions tests/examples/test_scripts.py
Expand Up @@ -29,9 +29,10 @@
_TABULAR_TESTING,
_TEXT_TESTING,
_VIDEO_TESTING,
_VISSL_AVAILABLE,
)
from tests.examples.utils import run_test
from tests.helpers.forked import forked
from tests.helpers.decorators import forked

root = Path(__file__).parent.parent.parent

Expand All @@ -56,6 +57,15 @@
"image_classification_multi_label.py",
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
),
pytest.param(
"image_embedder.py",
marks=[
pytest.mark.skipif(
not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="image libraries aren't installed"
),
pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU"),
],
),
pytest.param(
"object_detection.py",
marks=pytest.mark.skipif(
Expand All @@ -74,7 +84,6 @@
not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
),
),
# pytest.param("finetuning", "object_detection.py"), # TODO: takes too long.
pytest.param(
"question_answering.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
File renamed without changes.
27 changes: 24 additions & 3 deletions tests/image/embedding/test_model.py
Expand Up @@ -17,7 +17,12 @@
import torch

import flash
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
from flash.core.utilities.imports import (
_IMAGE_AVAILABLE,
_PL_GREATER_EQUAL_1_5_0,
_TORCHVISION_AVAILABLE,
_VISSL_AVAILABLE,
)
from flash.image import ImageClassificationData, ImageEmbedder

if _TORCHVISION_AVAILABLE:
Expand Down Expand Up @@ -50,6 +55,7 @@ def test_load_from_checkpoint_dependency_error():
ImageEmbedder.load_from_checkpoint("not_a_real_checkpoint.pt")


@pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU")
@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
@pytest.mark.parametrize(
"backbone, training_strategy, head, pretraining_transform",
Expand All @@ -70,7 +76,7 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
# moco strategy, transform and head is not added for this test as it doesn't work as of now.
datamodule = ImageClassificationData.from_datasets(
train_dataset=FakeData(16),
predict_dataset=FakeData(4),
predict_dataset=FakeData(8),
batch_size=4,
)

Expand All @@ -81,7 +87,22 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
pretraining_transform=pretraining_transform,
)

trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())
kwargs = {}

# DINO only works with DDP
if training_strategy == "dino":
if _PL_GREATER_EQUAL_1_5_0:
kwargs["strategy"] = "DDP"
else:
kwargs["accelerator"] = "DDP"

trainer = flash.Trainer(
max_steps=3,
max_epochs=1,
gpus=torch.cuda.device_count(),
**kwargs,
)

trainer.fit(embedder, datamodule=datamodule)
predictions = trainer.predict(embedder, datamodule=datamodule)
for prediction_batch in predictions: