Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Make dependency of vissl optional (#1276)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
  • Loading branch information
ar90n and ethanwharris committed Apr 8, 2022
1 parent b88d1ba commit f4f14b0
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))

### Deprecated

### Removed
Expand Down
4 changes: 3 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def decorator(func):

if not available:
modules = [f"'{module}'" for module in modules]
modules.append(f"'lightning-flash[{','.join(extras)}]'")

if extras:
modules.append(f"'lightning-flash[{','.join(extras)}]'")

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand Down
24 changes: 14 additions & 10 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ class ImageEmbedder(AdapterTask):
backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS

required_extras: List[str] = ["image", "vissl", "fairscale"]
required_extras: str = "image"

def __init__(
self,
training_strategy: str,
head: str,
pretraining_transform: str,
training_strategy: str = "default",
head: Optional[str] = None,
pretraining_transform: Optional[str] = None,
backbone: str = "resnet18",
pretrained: bool = False,
optimizer: OPTIMIZER_TYPE = "Adam",
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)

adapter = metadata["metadata"]["adapter"].from_task(
self,
task=self,
loss_fn=loss_fn,
backbone=model,
head=head,
Expand All @@ -128,12 +128,16 @@ def __init__(
learning_rate=learning_rate,
)

self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
if pretraining_transform is not None:
warnings.warn(
"Overriding any transforms from the `DataModule` with the pretraining transform: "
f"{pretraining_transform}."
)
self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)

warnings.warn(
"Warning: VISSL ImageEmbedder overrides any user provided transforms"
" with pre-defined transforms for the training strategy."
)
if "providers" in metadata["metadata"] and metadata["metadata"]["providers"].name == "Facebook Research/vissl":
if pretraining_transform is None:
raise ValueError("Correct pretraining_transform must be set to use VISSL")

def forward(self, x: torch.Tensor) -> Any:
return self.model(x)
Expand Down
2 changes: 2 additions & 0 deletions flash/image/embedding/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.strategies.default import register_default_strategy
from flash.image.embedding.strategies.vissl_strategies import register_vissl_strategies # noqa: F401

IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies")
register_vissl_strategies(IMAGE_EMBEDDER_STRATEGIES)
register_default_strategy(IMAGE_EMBEDDER_STRATEGIES)
89 changes: 89 additions & 0 deletions flash/image/embedding/strategies/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Optional

import torch

from flash.core.adapter import Adapter, AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.url_error import catch_url_error


class DefaultAdapter(Adapter):
    """Prediction-only :class:`~flash.core.adapter.Adapter` registered for the ``"default"`` strategy.

    The adapter wraps a backbone and nothing else. Its training, validation and test steps raise
    ``NotImplementedError``; only ``predict_step`` does any work, by delegating to
    ``Task.predict_step`` on the task the adapter was created from.
    """

    required_extras: str = "image"

    def __init__(self, backbone: torch.nn.Module):
        """Store ``backbone`` on the adapter.

        Args:
            backbone: The module kept as ``self.backbone``.
        """
        super().__init__()

        self.backbone = backbone

    @classmethod
    @catch_url_error
    def from_task(
        cls,
        task: AdapterTask,
        backbone: torch.nn.Module,
        **kwargs,
    ) -> Adapter:
        """Create an adapter around ``backbone`` and attach ``task`` to it.

        Args:
            task: The task the adapter belongs to; stored as ``_task``.
            backbone: The module handed to the constructor.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The newly created adapter.
        """
        adapter = cls(backbone)
        # Written straight into ``__dict__`` instead of using normal attribute assignment.
        # NOTE(review): presumably so the task is not registered as a child module of the
        # adapter - confirm against ``Adapter``.
        adapter.__dict__["_task"] = task
        return adapter

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raise: the default strategy cannot be trained.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            'Training an `ImageEmbedder` with `training_strategy="default"` is not supported. '
            "Use a different strategy instead."
        )

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raise: the default strategy cannot be validated.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            'Validating an `ImageEmbedder` with `training_strategy="default"` is not supported. '
            "Use a different strategy instead."
        )

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raise: the default strategy cannot be tested.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            'Testing an `ImageEmbedder` with `training_strategy="default"` is not supported. '
            "Use a different strategy instead."
        )

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Run the owning task's prediction on the batch input.

        Args:
            batch: Mapping read at ``DataKeys.INPUT`` and written at ``DataKeys.PREDS``.
            batch_idx: Forwarded to ``Task.predict_step``.
            dataloader_idx: Forwarded to ``Task.predict_step``.

        Returns:
            The same ``batch`` object, updated in place with the predictions.
        """
        batch[DataKeys.PREDS] = Task.predict_step(
            self._task, batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx
        )
        return batch


def default(head: Optional[str] = None, loss_fn: Optional[str] = None, **kwargs):
    """Build the components of the ``default`` strategy, which has none.

    The default strategy supports prediction only, so there is no loss function, no head and
    no hooks to create. A ``head`` or ``loss_fn`` that is passed anyway is reported with a
    warning and otherwise ignored, as are any extra keyword arguments.

    Args:
        head: Optional head name; ignored with a warning when not ``None``.
        loss_fn: Optional loss function name; ignored with a warning when not ``None``.
        **kwargs: Accepted and ignored.

    Returns:
        The tuple ``(None, None, [])`` standing for loss function, head and hooks.
    """
    notices = []

    if head is not None:
        notices.append(f"default strategy has no heads. So given head({head}) is ignored.")

    if loss_fn is not None:
        notices.append(f"default strategy has no loss functions. So given loss_fn({loss_fn}) is ignored.")

    for notice in notices:
        warnings.warn(notice)

    no_loss, no_head, no_hooks = None, None, []
    return no_loss, no_head, no_hooks


def register_default_strategy(register: FlashRegistry):
    """Add the ``default`` strategy function to ``register``.

    The entry is stored under the name ``"default"`` with ``DefaultAdapter`` passed as its
    ``adapter`` keyword.

    Args:
        register: The registry that is called to record the strategy.
    """
    strategy_name = "default"
    register(default, name=strategy_name, adapter=DefaultAdapter)
10 changes: 6 additions & 4 deletions flash/image/embedding/strategies/vissl_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.providers import _VISSL
from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS
from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS
Expand All @@ -23,20 +23,23 @@
from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook


@requires(["vissl", "classy_vision"])
def swav(head: str = "swav_head", **kwargs):
    """Create the loss function, head and hooks for the SwAV strategy.

    Args:
        head: Name looked up in ``IMAGE_EMBEDDER_HEADS``.
        **kwargs: Passed unchanged to both the loss factory and the head factory.

    Returns:
        A ``(loss_fn, head, hooks)`` tuple; ``hooks`` is a list of three hook instances.
    """
    loss_factory = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")
    loss_fn = loss_factory(**kwargs)

    head_factory = IMAGE_EMBEDDER_HEADS.get(head)
    head_module = head_factory(**kwargs)

    hooks = [SwAVUpdateQueueScoresHook(), NormalizePrototypesHook(), TrainingSetupHook()]
    return loss_fn, head_module, hooks


@requires(["vissl", "classy_vision"])
def simclr(head: str = "simclr_head", **kwargs):
    """Create the loss function, head and hooks for the SimCLR strategy.

    Args:
        head: Name looked up in ``IMAGE_EMBEDDER_HEADS``.
        **kwargs: Passed unchanged to both the loss factory and the head factory.

    Returns:
        A ``(loss_fn, head, hooks)`` tuple; ``hooks`` is a list holding one hook instance.
    """
    loss_factory = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("simclr_loss")
    loss_fn = loss_factory(**kwargs)

    head_factory = IMAGE_EMBEDDER_HEADS.get(head)
    head_module = head_factory(**kwargs)

    hooks = [SimCLRTrainingSetupHook()]
    return loss_fn, head_module, hooks


@requires(["vissl", "classy_vision"])
def barlow_twins(head: str = "barlow_twins_head", **kwargs):
loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("barlow_twins_loss")(**kwargs)
head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)
Expand All @@ -45,6 +48,5 @@ def barlow_twins(head: str = "barlow_twins_head", **kwargs):


def register_vissl_strategies(register: FlashRegistry):
if _VISSL_AVAILABLE:
for training_strategy in (swav, simclr, barlow_twins):
register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
for training_strategy in (swav, simclr, barlow_twins):
register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
14 changes: 11 additions & 3 deletions flash_examples/integrations/fiftyone/image_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageEmbedder
from flash.image.classification.data import ImageClassificationData

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
Expand All @@ -26,13 +29,18 @@
"data/hymenoptera_data/test/",
fo.types.ImageClassificationDirectoryTree,
)
datamodule = ImageClassificationData.from_files(
predict_files=dataset.values("filepath"),
batch_size=16,
)

# 3 Load model
embedder = ImageEmbedder(backbone="resnet101")
embedder = ImageEmbedder(backbone="resnet18")

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))
trainer = flash.Trainer(gpus=torch.cuda.device_count())
embedding_batches = trainer.predict(embedder, datamodule=datamodule)
embeddings = np.stack(sum(embedding_batches, []))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)
Expand Down
56 changes: 56 additions & 0 deletions tests/image/embedding/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,59 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
for prediction_batch in predictions:
for prediction in prediction_batch:
assert prediction.size(0) == embedding_size


@pytest.mark.skipif(not _IMAGE_AVAILABLE or not _VISSL_AVAILABLE, reason="vissl not installed.")
@pytest.mark.parametrize(
    "backbone, training_strategy, head, pretraining_transform, expected_exception",
    [
        ("resnet18", "simclr", "simclr_head", None, ValueError),
        ("resnet18", "simclr", None, "simclr_transform", KeyError),
    ],
)
def test_vissl_training_with_wrong_arguments(
    backbone, training_strategy, head, pretraining_transform, expected_exception
):
    """Check that an incomplete VISSL configuration is rejected at construction time.

    A missing pretraining transform is expected to raise ``ValueError`` and a missing head
    ``KeyError``.
    """
    embedder_kwargs = {
        "backbone": backbone,
        "training_strategy": training_strategy,
        "head": head,
        "pretraining_transform": pretraining_transform,
    }

    with pytest.raises(expected_exception):
        ImageEmbedder(**embedder_kwargs)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="torch vision not installed.")
@pytest.mark.parametrize(
    "backbone, embedding_size",
    [
        ("resnet18", 512),
        ("vit_small_patch16_224", 384),
    ],
)
def test_only_embedding(backbone, embedding_size):
    """Predict with an embedder built from a backbone alone and check the embedding size.

    No training strategy, head or pretraining transform is passed, so the constructor
    defaults are exercised.
    """
    datamodule = ImageClassificationData.from_datasets(
        predict_dataset=FakeData(8),
        batch_size=4,
        transform_kwargs={"image_size": (224, 224)},
    )

    embedder = ImageEmbedder(backbone=backbone)
    trainer = flash.Trainer()

    predictions = trainer.predict(embedder, datamodule=datamodule)

    # Walk every prediction of every batch lazily, in the same order as a nested loop would.
    for prediction in (item for prediction_batch in predictions for item in prediction_batch):
        assert prediction.size(0) == embedding_size


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="torch vision not installed.")
def test_not_implemented_steps():
    """Check that the training, validation and test steps of a backbone-only embedder raise."""
    embedder = ImageEmbedder(backbone="resnet18")

    for step_name in ("training_step", "validation_step", "test_step"):
        with pytest.raises(NotImplementedError):
            getattr(embedder, step_name)([], 0)

0 comments on commit f4f14b0

Please sign in to comment.