From 536dbfc8bff37b8727262ca9af2095c4ebd84335 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 28 Mar 2022 17:35:27 +0530
Subject: [PATCH] Propagate `collate_fn` to `InputTransform` in `ImageEmbedder`
 (#1217)

Co-authored-by: Ethan Harris
---
 CHANGELOG.md                                  |   6 +-
 flash/core/adapter.py                         |   9 ++
 flash/core/data/io/input_transform.py         |   2 +-
 flash/image/embedding/heads/vissl_heads.py    |  19 ++-
 flash/image/embedding/losses/vissl_losses.py  |  14 ++
 flash/image/embedding/model.py                |  12 +-
 .../embedding/transforms/vissl_transforms.py  |  21 +--
 flash/image/embedding/vissl/hooks.py          |   6 +-
 .../embedding/vissl/transforms/__init__.py    |   5 -
 .../embedding/vissl/transforms/multicrop.py   | 121 +++++++++---------
 flash_examples/image_embedder.py              |   2 +-
 .../integrations/vissl/test_transforms.py     |  41 ------
 tests/image/embedding/test_model.py           |  16 +--
 13 files changed, 130 insertions(+), 144 deletions(-)
 delete mode 100644 tests/core/integrations/vissl/test_transforms.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b0eb00981..99643da230 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,9 +20,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
+- Fixed a bug where collate functions were never called in the `ImageEmbedder` class. ([#1217](https://github.com/PyTorchLightning/lightning-flash/pull/1217))
 
-- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
+- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
+
+- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
 
 - Fixed normalizing inputs to video classification ([#1213](https://github.com/PyTorchLightning/lightning-flash/pull/1213))
 
diff --git a/flash/core/adapter.py b/flash/core/adapter.py
index 433fdb74cb..7a159dd2d0 100644
--- a/flash/core/adapter.py
+++ b/flash/core/adapter.py
@@ -89,6 +89,15 @@ def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]:
     def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None:
         self.adapter.input_transform = input_transform
 
+    @torch.jit.unused
+    @property
+    def collate_fn(self) -> Optional[Callable]:
+        return self.adapter.collate_fn
+
+    @collate_fn.setter
+    def collate_fn(self, collate_fn: Callable) -> None:
+        self.adapter.collate_fn = collate_fn
+
     @torch.jit.unused
     @property
     def backbone(self) -> nn.Module:
diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py
index d40bec94c5..a86ad28ff0 100644
--- a/flash/core/data/io/input_transform.py
+++ b/flash/core/data/io/input_transform.py
@@ -1060,7 +1060,7 @@ def create_or_configure_input_transform(
         )
         return transform(**transform_kwargs)
 
-    if isinstance(transform, partial) and transform.func.__name__ == "LambdaInputTransform":
+    if isinstance(transform, partial):
         return transform(**transform_kwargs)
 
     if isinstance(transform, Callable):
diff --git a/flash/image/embedding/heads/vissl_heads.py b/flash/image/embedding/heads/vissl_heads.py
index cb2798fe38..4c9b6f4ee9 100644
--- a/flash/image/embedding/heads/vissl_heads.py
+++ b/flash/image/embedding/heads/vissl_heads.py
@@ -89,13 +89,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 def simclr_head(
-    dims: List[int] = [2048, 2048, 256],
+    num_features: int = 2048,
+    embedding_dim: int = 128,
+    dims: List[int] = [2048],
     use_bn: bool = True,
     **kwargs,
 ) -> nn.Module:
     cfg = VISSLAdapter.get_model_config_template()
     head_kwargs = {
-        "dims": dims,
+        "dims": [num_features] + dims + [embedding_dim],
         "use_bn": use_bn,
     }
 
@@ -108,7 +110,9 @@ def simclr_head(
 
 
 def swav_head(
-    dims: List[int] = [2048, 2048, 128],
+    num_features: int = 2048,
+    embedding_dim: int = 128,
+    dims: List[int] = [2048],
     use_bn: bool = True,
     num_clusters: Union[int, List[int]] = [3000],
     use_bias: bool = True,
@@ -121,7 +125,7 @@ def swav_head(
 ) -> nn.Module:
     cfg = VISSLAdapter.get_model_config_template()
     head_kwargs = {
-        "dims": dims,
+        "dims": [num_features] + dims + [embedding_dim],
         "use_bn": use_bn,
         "num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters,
         "use_bias": use_bias,
@@ -140,8 +144,11 @@ def swav_head(
     return head
 
 
-def barlow_twins_head(**kwargs) -> nn.Module:
-    return simclr_head(**kwargs)
+def barlow_twins_head(
+    latent_embedding_dim: int = 8192,
+    **kwargs,
+) -> nn.Module:
+    return simclr_head(embedding_dim=latent_embedding_dim, **kwargs)
 
 
 def moco_head(**kwargs) -> nn.Module:
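The head factories above no longer hard-code the full projector shape: `dims` now holds only the hidden layer sizes, and the backbone's `num_features` plus the target `embedding_dim` are spliced around it. A minimal sketch of the composition rule (the `mlp_dims` helper is illustrative only, not part of the patch):

```python
from typing import List, Sequence


def mlp_dims(num_features: int = 2048, embedding_dim: int = 128, dims: Sequence[int] = (2048,)) -> List[int]:
    # mirrors `"dims": [num_features] + dims + [embedding_dim]` in simclr_head / swav_head
    return [num_features] + list(dims) + [embedding_dim]


assert mlp_dims() == [2048, 2048, 128]  # ResNet-50 features -> default SimCLR projector
assert mlp_dims(num_features=384) == [384, 2048, 128]  # ViT-S backbone, no manual override needed
```

This is what lets `ImageEmbedder` forward `num_features` straight from the backbone (see the `model.py` hunk below) instead of requiring callers to repeat it in `training_strategy_kwargs`.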
diff --git a/flash/image/embedding/losses/vissl_losses.py b/flash/image/embedding/losses/vissl_losses.py
index b1ba8f936b..d63d0e781a 100644
--- a/flash/image/embedding/losses/vissl_losses.py
+++ b/flash/image/embedding/losses/vissl_losses.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import List, Union
 
+import torch.cuda
+
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _VISSL_AVAILABLE
 
@@ -26,11 +28,23 @@
     ClassyLoss = object
 
 
+def _recursive_register(module):
+    named_tensors = [(key, value) for key, value in module.__dict__.items() if isinstance(value, torch.Tensor)]
+    for name, tensor in named_tensors:
+        delattr(module, name)
+        module.register_buffer(name, tensor)
+
+    for child_module in module.modules():
+        if child_module is not module:
+            _recursive_register(child_module)
+
+
 def get_loss_fn(loss_name: str, cfg: AttrDict):
     set_cpu_device()
     loss_fn = LOSS_REGISTRY[loss_name](cfg)
     loss_fn.__dict__["loss_name"] = loss_name
+    _recursive_register(loss_fn)
 
     return loss_fn
diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py
index 54d01ca043..6ca62f5528 100644
--- a/flash/image/embedding/model.py
+++ b/flash/image/embedding/model.py
@@ -12,13 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import warnings
-from functools import partial
 from typing import Any, Dict, List, Optional
 
 from flash.core.adapter import AdapterTask
-from flash.core.data.io.input import DataKeys
-from flash.core.data.io.input_transform import LambdaInputTransform
-from flash.core.data.transforms import ApplyToKeys
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
 from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
@@ -92,10 +88,10 @@ def __init__(
         if pretraining_transform_kwargs is None:
             pretraining_transform_kwargs = {}
 
-        backbone, _ = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)
+        backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)
 
         metadata = self.training_strategies.get(training_strategy, with_metadata=True)
-        loss_fn, head, hooks = metadata["fn"](head=head, **training_strategy_kwargs)
+        loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)
 
         adapter = metadata["metadata"]["adapter"].from_task(
             self,
@@ -112,9 +108,7 @@ def __init__(
             learning_rate=learning_rate,
         )
 
-        input_transform, self.collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
-        output = ApplyToKeys(DataKeys.INPUT, input_transform)
-        self.input_transform = partial(LambdaInputTransform, transform=output)
+        self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
 
         warnings.warn(
             "Warning: VISSL ImageEmbedder overrides any user provided transforms"
diff --git a/flash/image/embedding/transforms/vissl_transforms.py b/flash/image/embedding/transforms/vissl_transforms.py
index 8e54354a4f..e6be31bd6f 100644
--- a/flash/image/embedding/transforms/vissl_transforms.py
+++ b/flash/image/embedding/transforms/vissl_transforms.py
@@ -17,11 +17,8 @@
 import torch.nn as nn
 
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _VISSL_AVAILABLE
 from flash.image.embedding.vissl.transforms import moco_collate_fn, multicrop_collate_fn, simclr_collate_fn
-
-if _VISSL_AVAILABLE:
-    from classy_vision.dataset.transforms import TRANSFORM_REGISTRY
+from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform
 
 
 def simclr_transform(
@@ -33,9 +30,10 @@ def simclr_transform(
     jitter_strength: float = 1.0,
     normalize: Optional[nn.Module] = None,
     collate_fn: Callable = simclr_collate_fn,
-) -> nn.Module:
+) -> partial:
     """For simclr, barlow twins and moco."""
-    transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
+    transform = partial(
+        StandardMultiCropSSLTransform,
         total_num_crops=total_num_crops,
         num_crops=num_crops,
         size_crops=size_crops,
@@ -43,9 +41,10 @@ def simclr_transform(
         gaussian_blur=gaussian_blur,
         jitter_strength=jitter_strength,
         normalize=normalize,
+        collate_fn=collate_fn,
     )
 
-    return transform, collate_fn
+    return transform
 
 
 def swav_transform(
@@ -57,9 +56,10 @@ def swav_transform(
     jitter_strength: float = 1.0,
     normalize: Optional[nn.Module] = None,
     collate_fn: Callable = multicrop_collate_fn,
-) -> nn.Module:
+) -> partial:
     """For swav and dino."""
-    transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
+    transform = partial(
+        StandardMultiCropSSLTransform,
         total_num_crops=total_num_crops,
         num_crops=num_crops,
         size_crops=size_crops,
@@ -67,9 +67,10 @@ def swav_transform(
         gaussian_blur=gaussian_blur,
         jitter_strength=jitter_strength,
         normalize=normalize,
+        collate_fn=collate_fn,
     )
 
-    return transform, collate_fn
+    return transform
 
 
 barlow_twins_transform = partial(simclr_transform, collate_fn=simclr_collate_fn)
diff --git a/flash/image/embedding/vissl/hooks.py b/flash/image/embedding/vissl/hooks.py
index d9e7369973..6bd090a5ef 100644
--- a/flash/image/embedding/vissl/hooks.py
+++ b/flash/image/embedding/vissl/hooks.py
@@ -49,7 +49,11 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") -
 
         # get around vissl distributed training by setting MockTask flags
         num_nodes = lightning_module.trainer.num_nodes
-        accelerators_ids = accelerator_connector(lightning_module.trainer).parallel_device_ids
+        accelerators_ids = getattr(
+            lightning_module.trainer,
+            "device_ids",
+            getattr(accelerator_connector(lightning_module.trainer), "parallel_device_ids", None),
+        )
         accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
         task.world_size = num_nodes * accelerator_per_node
diff --git a/flash/image/embedding/vissl/transforms/__init__.py b/flash/image/embedding/vissl/transforms/__init__.py
index 447aef4fa7..d09a06e502 100644
--- a/flash/image/embedding/vissl/transforms/__init__.py
+++ b/flash/image/embedding/vissl/transforms/__init__.py
@@ -5,8 +5,3 @@
     multicrop_collate_fn,
     simclr_collate_fn,
 )
-
-if _VISSL_AVAILABLE:
-    from classy_vision.dataset.transforms import register_transform  # noqa: F401
-
-    register_transform("multicrop_ssl_transform")(StandardMultiCropSSLTransform)
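The `_recursive_register` helper added above in `vissl_losses.py` works around VISSL losses stashing plain tensors (running centers, queues, and the like) in `__dict__`, where `nn.Module.to()` cannot see them. Re-registering them as buffers makes them follow the module across devices and dtypes. A toy sketch of the effect (`ToyLoss` is a stand-in, not a real VISSL loss):

```python
import torch
from torch import nn

from flash.image.embedding.losses.vissl_losses import _recursive_register


class ToyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.center = torch.zeros(4)  # plain attribute: invisible to nn.Module.to()


loss = ToyLoss()
_recursive_register(loss)
assert "center" in dict(loss.named_buffers())

loss = loss.to(torch.float64)  # the registered buffer now follows the module
assert loss.center.dtype == torch.float64
```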
diff --git a/flash/image/embedding/vissl/transforms/multicrop.py b/flash/image/embedding/vissl/transforms/multicrop.py
index 49207ecf96..b9fe2c557a 100644
--- a/flash/image/embedding/vissl/transforms/multicrop.py
+++ b/flash/image/embedding/vissl/transforms/multicrop.py
@@ -11,18 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Sequence
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any, Callable, Dict, Optional, Sequence
 
-import numpy as np
 import torch.nn as nn
+from torch.utils.data._utils.collate import default_collate
 
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image
+from flash.core.data.io.input import DataKeys
+from flash.core.data.io.input_transform import InputTransform
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
 
 if _TORCHVISION_AVAILABLE:
     import torchvision.transforms as pth_transforms
 
 
-class StandardMultiCropSSLTransform(nn.Module):
+@dataclass
+class StandardMultiCropSSLTransform(InputTransform):
     """Convert a PIL image to Multi-resolution Crops. The input is a PIL image and output is the list of image
     crops.
 
@@ -31,51 +36,44 @@ class StandardMultiCropSSLTransform(nn.Module):
 
     This transform has been modified from the ImgPilToMultiCrop code present at
     https://github.com/facebookresearch/vissl/blob/master/vissl/data/ssl_transforms/img_pil_to_multicrop.py
+
+    Args:
+        total_num_crops (int): Total number of crops to extract
+        num_crops (List or Tuple of ints): Specifies the number of each type of crop
+        size_crops (List or Tuple of ints): Specifies the height (height = width) of each patch
+        crop_scales (List or Tuple containing [float, float]): Scale of the crop
+        gaussian_blur (bool): Specifies if the transforms' composition has Gaussian Blur
+        jitter_strength (float): Specifies the coefficient for the color jitter transform
+        normalize (Optional): Normalize transform from torchvision with params set according to the dataset
     """
 
-    def __init__(
-        self,
-        total_num_crops: int,
-        num_crops: Sequence[int],
-        size_crops: Sequence[int],
-        crop_scales: Sequence[Sequence[float]],
-        gaussian_blur: bool = True,
-        jitter_strength: float = 1.0,
-        normalize: Optional[nn.Module] = None,
-    ):
-        """Returns total_num_crops square crops of an image. Each crop is a random crop extracted according to the
-        parameters specified in size_crops and crop_scales. For ease of use, one can specify `num_crops` which
-        removes the need to repeat parameters.
-
-        Args:
-            total_num_crops (int): Total number of crops to extract
-            num_crops (List or Tuple of ints): Specifies the number of `type' of crops.
-            size_crops (List or Tuple of ints): Specifies the height (height = width)
-                of each patch
-            crop_scales (List or Tuple containing [float, float]): Scale of the crop
-            gaussian_blur (bool): Specifies if the transforms' composition has Gaussian Blur
-            jitter_strength (float): Specify the coefficient for color jitter transform
-            normalize (Optional): Normalize transform from torchvision with params set
-                according to the dataset
-
-        Example usage:
-        - (total_num_crops=2, num_crops=[1, 1],
-          size_crops=[224, 96], crop_scales=[(0.14, 1.), (0.05, 0.14)])
-          Extracts 2 crops total of size 224x224 and 96x96
-        - (total_num_crops=3, num_crops=[1, 2],
-          size_crops=[224, 96], crop_scales=[(0.14, 1.), (0.05, 0.14)])
-          Extracts 3 crops total: 1 of size 224x224 and 2 of size 96x96
-        """
-        super().__init__()
-
-        assert np.sum(num_crops) == total_num_crops
-        assert len(size_crops) == len(num_crops)
-        assert len(size_crops) == len(crop_scales)
-
-        self.gaussian_blur = gaussian_blur
-        self.jitter_strength = jitter_strength
-        self.normalize = normalize
+    total_num_crops: int = 2
+    num_crops: Sequence[int] = field(default_factory=lambda: [2])
+    size_crops: Sequence[int] = field(default_factory=lambda: [224])
+    crop_scales: Sequence[Sequence[float]] = field(default_factory=lambda: [[0.4, 1]])
+    gaussian_blur: bool = True
+    jitter_strength: float = 1.0
+    normalize: Optional[nn.Module] = None
+    collate_fn: Callable = default_collate
+
+    @staticmethod
+    def _apply(transform, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample[DataKeys.INPUT] = transform(sample[DataKeys.INPUT])
+        return sample
+
+    @staticmethod
+    def _parallel_apply(transforms, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample[DataKeys.INPUT] = [transform(sample[DataKeys.INPUT]) for transform in transforms]
+        return sample
+
+    def _get_final_transform(self) -> Callable:
+        if self.normalize is None:
+            final_transform = pth_transforms.ToTensor()
+        else:
+            final_transform = pth_transforms.Compose([pth_transforms.ToTensor(), self.normalize])
+        return final_transform
 
+    def per_sample_transform(self) -> Callable:
         color_jitter = pth_transforms.ColorJitter(
             0.8 * self.jitter_strength,
             0.8 * self.jitter_strength,
@@ -85,7 +83,7 @@ def __init__(
         color_transform = [pth_transforms.RandomApply([color_jitter], p=0.8), pth_transforms.RandomGrayscale(p=0.2)]
 
         if self.gaussian_blur:
-            kernel_size = int(0.1 * size_crops[0])
+            kernel_size = int(0.1 * self.size_crops[0])
             if kernel_size % 2 == 0:
                 kernel_size += 1
 
@@ -93,31 +91,36 @@ def __init__(
                 pth_transforms.RandomApply([pth_transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5)
             )
 
-        self.color_transform = pth_transforms.Compose(color_transform)
+        color_transform = pth_transforms.Compose(color_transform)
 
-        if normalize is None:
-            self.final_transform = pth_transforms.ToTensor()
-        else:
-            self.final_transform = pth_transforms.Compose([pth_transforms.ToTensor(), normalize])
+        final_transform = self._get_final_transform()
 
         transforms = []
-        for num, size, scale in zip(num_crops, size_crops, crop_scales):
+        for num, size, scale in zip(self.num_crops, self.size_crops, self.crop_scales):
             transforms.extend(
                 [
                     pth_transforms.Compose(
                         [
                             pth_transforms.RandomResizedCrop(size, scale=scale),
                             pth_transforms.RandomHorizontalFlip(p=0.5),
-                            self.color_transform,
-                            self.final_transform,
+                            color_transform,
+                            final_transform,
                         ]
                     )
                 ]
                 * num
             )
 
-        self.transforms = transforms
+        return partial(self._parallel_apply, transforms)
+
+    def collate(self) -> Callable:
+        return self.collate_fn
+
+    def predict_per_sample_transform(self) -> Callable:
+        return partial(
+            self._apply,
+            pth_transforms.Compose([pth_transforms.CenterCrop(self.size_crops[0]), self._get_final_transform()]),
+        )
 
-    def __call__(self, image: Image.Image) -> List[Image.Image]:
-        images = [transform(image) for transform in self.transforms]
-        return images
+    def predict_collate(self) -> Callable:
+        return default_collate
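With this rewrite, `StandardMultiCropSSLTransform` no longer implements `__call__`: training fans one image out into `total_num_crops` differently-augmented views via `_parallel_apply`, while prediction center-crops a single view. A rough sketch of the fan-out, calling the static helper directly and bypassing the usual `ImageEmbedder` wiring (crop sizes echo the deleted docstring example):

```python
import torchvision.transforms as T
from PIL import Image

from flash.core.data.io.input import DataKeys
from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform

# one 224x224 crop and two 96x96 crops from a single image
crops = [T.Compose([T.RandomResizedCrop(s), T.ToTensor()]) for s in (224, 96, 96)]
sample = {DataKeys.INPUT: Image.new("RGB", (256, 256))}

out = StandardMultiCropSSLTransform._parallel_apply(crops, sample)
assert [c.shape[-1] for c in out[DataKeys.INPUT]] == [224, 96, 96]
```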
diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py
index 5f272a834d..7edc360427 100644
--- a/flash_examples/image_embedder.py
+++ b/flash_examples/image_embedder.py
@@ -28,7 +28,7 @@
 embedder = ImageEmbedder(
     backbone="resnet",
     training_strategy="barlow_twins",
-    head="simclr_head",
+    head="barlow_twins_head",
     pretraining_transform="barlow_twins_transform",
     training_strategy_kwargs={"latent_embedding_dim": 128},
     pretraining_transform_kwargs={"size_crops": [196]},
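For context, the corrected example constructs the embedder as below. With `barlow_twins_head`, the `latent_embedding_dim` argument is actually consumed; `simclr_head` merely absorbed it via `**kwargs`:

```python
from flash.image import ImageEmbedder

embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="barlow_twins_head",  # was "simclr_head", which ignored latent_embedding_dim
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [196]},
)
```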
diff --git a/tests/core/integrations/vissl/test_transforms.py b/tests/core/integrations/vissl/test_transforms.py
deleted file mode 100644
index e4264c55c0..0000000000
--- a/tests/core/integrations/vissl/test_transforms.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pytest
-
-from flash.core.data.io.input import DataKeys
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
-from tests.image.embedding.utils import ssl_datamodule
-
-
-@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
-def test_multicrop_input_transform():
-    batch_size = 8
-    total_num_crops = 6
-    num_crops = [2, 4]
-    size_crops = [160, 96]
-    crop_scales = [[0.4, 1], [0.05, 0.4]]
-
-    datamodule = ssl_datamodule(
-        batch_size=batch_size,
-        total_num_crops=total_num_crops,
-        num_crops=num_crops,
-        size_crops=size_crops,
-        crop_scales=crop_scales,
-    )
-    batch = next(iter(datamodule.train_dataloader()))
-
-    assert len(batch[DataKeys.INPUT]) == total_num_crops
-    assert batch[DataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0])
-    assert batch[DataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1])
-    assert list(batch[DataKeys.TARGET].shape) == [batch_size]
diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py
index cb042e0053..9cd78e8b74 100644
--- a/tests/image/embedding/test_model.py
+++ b/tests/image/embedding/test_model.py
@@ -62,30 +62,28 @@ def test_load_from_checkpoint_dependency_error():
             "dino_transform",
             marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="VISSL DINO calls all_reduce internally."),
         ),
-        ("vision_transformer", "barlow_twins", "simclr_head", "barlow_twins_transform"),
+        ("vision_transformer", "barlow_twins", "barlow_twins_head", "barlow_twins_transform"),
         ("vision_transformer", "swav", "swav_head", "swav_transform"),
     ],
 )
 def test_vissl_training(backbone, training_strategy, head, pretraining_transform):
     # moco strategy, transform and head is not added for this test as it doesn't work as of now.
     datamodule = ImageClassificationData.from_datasets(
-        train_dataset=FakeData(),
+        train_dataset=FakeData(16),
+        predict_dataset=FakeData(4),
         batch_size=4,
     )
 
-    training_strategy_kwargs = {
-        "dims": [384, 2048, 2048, 256],
-    }
-    dim_key = "latent_embedding_dim" if training_strategy == "barlow_twins" else "embedding_dim"
-    training_strategy_kwargs[dim_key] = 256
-
     embedder = ImageEmbedder(
         backbone=backbone,
         training_strategy=training_strategy,
         head=head,
         pretraining_transform=pretraining_transform,
-        training_strategy_kwargs=training_strategy_kwargs,
     )
 
     trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())
     trainer.fit(embedder, datamodule=datamodule)
+    predictions = trainer.predict(embedder, datamodule=datamodule)
+    for prediction_batch in predictions:
+        for prediction in prediction_batch:
+            assert prediction.size(0) == 384
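Distilled from the updated test, a standalone sketch of the new predict path (assumes VISSL is installed; `vision_transformer` is a ViT-S backbone, whose embeddings are 384-dimensional):

```python
import torch
from torchvision.datasets import FakeData

import flash
from flash.image import ImageClassificationData, ImageEmbedder

datamodule = ImageClassificationData.from_datasets(predict_dataset=FakeData(4), batch_size=4)
embedder = ImageEmbedder(
    backbone="vision_transformer",
    training_strategy="simclr",
    head="simclr_head",
    pretraining_transform="simclr_transform",
)

trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(embedder, datamodule=datamodule)
for batch in predictions:
    for embedding in batch:
        assert embedding.size(0) == 384  # ViT-S feature dimension
```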