diff --git a/CHANGELOG.md b/CHANGELOG.md index d7a9a00a0f..d0235a2668 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed the script of integrating `lightning-flash` with `learn2learn` ([#1376](https://github.com/Lightning-AI/lightning-flash/pull/1383)) + - Fixed JIT tracing tests where the model class was not attached to the `Trainer` class ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410)) - Fixed examples for BaaL integration by removing usage of `on__dataloader` hooks (removed in PL 1.7.0) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410)) diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index d1f5dea28a..8af2ca255d 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -14,63 +14,115 @@ # adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154 +"""## Train file https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1 + +## Validation File +https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1 + +Followed by renaming the pickle files +cp './mini-imagenet-cache-train.pkl?dl=1' './mini-imagenet-cache-train.pkl' +cp './mini-imagenet-cache-validation.pkl?dl=1' './mini-imagenet-cache-validation.pkl' +""" + import warnings +from dataclasses import dataclass +from typing import Tuple, Union import kornia.augmentation as Ka import kornia.geometry as Kg import learn2learn as l2l import torch -import torchvision -from torch import nn +import torchvision.transforms as T import flash from flash.core.data.io.input import DataKeys +from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys, kornia_collate from flash.image import ImageClassificationData, ImageClassifier warnings.simplefilter("ignore") # download MiniImagenet -train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True) -val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True) +train_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="train", download=False) +val_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="validation", download=False) + + +@dataclass +class ImageClassificationInputTransform(InputTransform): + + image_size: Tuple[int, int] = (196, 196) + mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) + std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) -transform = { - "per_sample_transform": nn.Sequential( - ApplyToKeys( + def per_sample_transform(self): + return T.Compose( + [ + ApplyToKeys( + DataKeys.INPUT, + T.Compose( + [ + T.ToTensor(), + Kg.Resize((196, 196)), + # SPATIAL + Ka.RandomHorizontalFlip(p=0.25), + Ka.RandomRotation(degrees=90.0, p=0.25), + Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), + Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), + # PIXEL-LEVEL + Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness + Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation + Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast + Ka.ColorJitter(hue=1 / 30, p=0.25), # hue + Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), + Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), + ] + ), + ), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), + ] + ) + + def train_per_sample_transform(self): + return T.Compose( + [ + ApplyToKeys( + DataKeys.INPUT, + T.Compose( + [ + T.ToTensor(), + T.Resize(self.image_size), + T.Normalize(self.mean, self.std), + T.RandomHorizontalFlip(), + T.ColorJitter(), + T.RandomAutocontrast(), + T.RandomPerspective(), + ] + ), + ), + ApplyToKeys("target", torch.as_tensor), + ] + ) + + def per_batch_transform_on_device(self): + return ApplyToKeys( DataKeys.INPUT, - nn.Sequential( - torchvision.transforms.ToTensor(), - Kg.Resize((196, 196)), - # SPATIAL - Ka.RandomHorizontalFlip(p=0.25), - Ka.RandomRotation(degrees=90.0, p=0.25), - Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), - Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), - # PIXEL-LEVEL - Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness - Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation - Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast - Ka.ColorJitter(hue=1 / 30, p=0.25), # hue - Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), - Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), - ), - ), - ApplyToKeys(DataKeys.TARGET, torch.as_tensor), - ), - "collate": kornia_collate, - "per_batch_transform_on_device": ApplyToKeys( - DataKeys.INPUT, - Ka.RandomHorizontalFlip(p=0.25), - ), -} + Ka.RandomHorizontalFlip(p=0.25), + ) + + def collate(self): + return kornia_collate + # construct datamodule + datamodule = ImageClassificationData.from_tensors( train_data=train_dataset.x, train_targets=torch.from_numpy(train_dataset.y.astype(int)), val_data=val_dataset.x, val_targets=torch.from_numpy(val_dataset.y.astype(int)), - transform=transform, + train_transform=ImageClassificationInputTransform, + val_transform=ImageClassificationInputTransform, + batch_size=1, ) model = ImageClassifier( @@ -78,7 +130,7 @@ training_strategy="prototypicalnetworks", training_strategy_kwargs={ "epoch_length": 10 * 16, - "meta_batch_size": 4, + "meta_batch_size": 1, "num_tasks": 200, "test_num_tasks": 2000, "ways": datamodule.num_classes, @@ -92,9 +144,9 @@ ) trainer = flash.Trainer( - max_epochs=200, - gpus=2, - accelerator="ddp_shared", + max_epochs=1, + gpus=1, precision=16, ) + trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")