Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Updated the learn2learn "image_classification_imagenette_mini" example #1383

Merged
merged 13 commits into from
Aug 26, 2022
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed the script of integrating `lightning-flash` with `learn2learn` ([#1376](https://github.com/Lightning-AI/lightning-flash/pull/1383))

- Fixed JIT tracing tests where the model class was not attached to the `Trainer` class ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))

- Fixed examples for BaaL integration by removing usage of `on_<stage>_dataloader` hooks (removed in PL 1.7.0) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,71 +14,123 @@

# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154

"""## Train file https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1

## Validation File
https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1

Followed by renaming the pickle files
cp './mini-imagenet-cache-train.pkl?dl=1' './mini-imagenet-cache-train.pkl'
cp './mini-imagenet-cache-validation.pkl?dl=1' './mini-imagenet-cache-validation.pkl'
"""

import warnings
from dataclasses import dataclass
from typing import Tuple, Union

import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import torch
import torchvision
from torch import nn
import torchvision.transforms as T

import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier

warnings.simplefilter("ignore")

# download MiniImagenet
train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True)
train_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="train", download=False)
val_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="validation", download=False)


@dataclass
class ImageClassificationInputTransform(InputTransform):

image_size: Tuple[int, int] = (196, 196)
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

transform = {
"per_sample_transform": nn.Sequential(
ApplyToKeys(
def per_sample_transform(self):
return T.Compose(
[
ApplyToKeys(
DataKeys.INPUT,
T.Compose(
[
T.ToTensor(),
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
]
),
),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
]
)

def train_per_sample_transform(self):
return T.Compose(
[
ApplyToKeys(
DataKeys.INPUT,
T.Compose(
[
T.ToTensor(),
T.Resize(self.image_size),
T.Normalize(self.mean, self.std),
T.RandomHorizontalFlip(),
T.ColorJitter(),
T.RandomAutocontrast(),
T.RandomPerspective(),
]
),
),
ApplyToKeys("target", torch.as_tensor),
]
)

def per_batch_transform_on_device(self):
return ApplyToKeys(
DataKeys.INPUT,
nn.Sequential(
torchvision.transforms.ToTensor(),
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
),
),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
),
"collate": kornia_collate,
"per_batch_transform_on_device": ApplyToKeys(
DataKeys.INPUT,
Ka.RandomHorizontalFlip(p=0.25),
),
}
Ka.RandomHorizontalFlip(p=0.25),
)

def collate(self):
return kornia_collate


# construct datamodule

datamodule = ImageClassificationData.from_tensors(
train_data=train_dataset.x,
train_targets=torch.from_numpy(train_dataset.y.astype(int)),
val_data=val_dataset.x,
val_targets=torch.from_numpy(val_dataset.y.astype(int)),
transform=transform,
train_transform=ImageClassificationInputTransform,
val_transform=ImageClassificationInputTransform,
batch_size=1,
)

model = ImageClassifier(
backbone="resnet18",
training_strategy="prototypicalnetworks",
training_strategy_kwargs={
"epoch_length": 10 * 16,
"meta_batch_size": 4,
"meta_batch_size": 1,
"num_tasks": 200,
"test_num_tasks": 2000,
"ways": datamodule.num_classes,
Expand All @@ -92,9 +144,9 @@
)

trainer = flash.Trainer(
max_epochs=200,
gpus=2,
accelerator="ddp_shared",
max_epochs=1,
gpus=1,
precision=16,
)

trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")