Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Updated the learn2learn "image_classification_imagenette_mini" example #1383

Merged
merged 13 commits into from
Aug 26, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -14,71 +14,139 @@

# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154

"""
Requirements:

pip install learn2learn
pip install kornia
pip install lightning-flash
pip install 'lightning-flash[image]'
"""

krshrimali marked this conversation as resolved.
Show resolved Hide resolved
"""
## Train file
https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1

## Validation File
https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1

Followed by renaming the pickle files
cp './mini-imagenet-cache-train.pkl?dl=1' './mini-imagenet-cache-train.pkl'
cp './mini-imagenet-cache-validation.pkl?dl=1' './mini-imagenet-cache-validation.pkl'

"""

import warnings
from dataclasses import dataclass
from typing import Callable, Tuple, Union

import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from torch import nn
krshrimali marked this conversation as resolved.
Show resolved Hide resolved

import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier

warnings.simplefilter("ignore")

# download MiniImagenet
train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True)
train_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="train", download=False)
val_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="validation", download=False)


@dataclass
class ImageClassificationInputTransform(InputTransform):

transform = {
"per_sample_transform": nn.Sequential(
ApplyToKeys(
image_size: Tuple[int, int] = (196, 196)
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

def per_sample_transform(self):
return T.Compose(
[
ApplyToKeys(
DataKeys.INPUT,
T.Compose(
[
T.ToTensor(),
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
]
),
),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
]
)

def train_per_sample_transform(self):
return T.Compose(
[
ApplyToKeys(
"input",
krshrimali marked this conversation as resolved.
Show resolved Hide resolved
T.Compose(
[
T.ToTensor(),
T.Resize(self.image_size),
T.Normalize(self.mean, self.std),
T.RandomHorizontalFlip(),
T.ColorJitter(),
T.RandomAutocontrast(),
T.RandomPerspective(),
]
),
),
ApplyToKeys("target", torch.as_tensor),
]
)

def per_batch_transform_on_device(self):
return ApplyToKeys(
DataKeys.INPUT,
nn.Sequential(
torchvision.transforms.ToTensor(),
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
),
),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
),
"collate": kornia_collate,
"per_batch_transform_on_device": ApplyToKeys(
DataKeys.INPUT,
Ka.RandomHorizontalFlip(p=0.25),
),
}
Ka.RandomHorizontalFlip(p=0.25),
)

def collate(self):
return kornia_collate


# construct datamodule

datamodule = ImageClassificationData.from_tensors(
train_data=train_dataset.x,
train_targets=torch.from_numpy(train_dataset.y.astype(int)),
val_data=val_dataset.x,
val_targets=torch.from_numpy(val_dataset.y.astype(int)),
transform=transform,
train_transform=ImageClassificationInputTransform,
val_transform=ImageClassificationInputTransform,
batch_size=1,
)

model = ImageClassifier(
backbone="resnet18",
training_strategy="prototypicalnetworks",
training_strategy_kwargs={
"epoch_length": 10 * 16,
"meta_batch_size": 4,
"meta_batch_size": 1,
"num_tasks": 200,
"test_num_tasks": 2000,
"ways": datamodule.num_classes,
Expand All @@ -92,9 +160,9 @@
)

trainer = flash.Trainer(
max_epochs=200,
gpus=2,
accelerator="ddp_shared",
max_epochs=1,
gpus=1,
precision=16,
)

trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")