This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit 6859fa1

Updated the learn2learn "image_classification_imagenette_mini" example (#1383)

* Created using Colaboratory

* Removed the notebook

* Updated the script and removed the notebook

* Apply suggestions from code review

* Update flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py

Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
uakarsh and krshrimali authored Aug 26, 2022
1 parent 0e21259 commit 6859fa1
Showing 2 changed files with 91 additions and 37 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed the script of integrating `lightning-flash` with `learn2learn` ([#1376](https://github.com/Lightning-AI/lightning-flash/pull/1383))

- Fixed JIT tracing tests where the model class was not attached to the `Trainer` class ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))

- Fixed examples for BaaL integration by removing usage of `on_<stage>_dataloader` hooks (removed in PL 1.7.0) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))
126 changes: 89 additions & 37 deletions flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py
@@ -14,71 +14,123 @@

# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154

"""## Train file https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1
## Validation File
https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1
Followed by renaming the pickle files
cp './mini-imagenet-cache-train.pkl?dl=1' './mini-imagenet-cache-train.pkl'
cp './mini-imagenet-cache-validation.pkl?dl=1' './mini-imagenet-cache-validation.pkl'
"""

import warnings
from dataclasses import dataclass
from typing import Tuple, Union

import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import torch
import torchvision
from torch import nn
import torchvision.transforms as T

import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier

warnings.simplefilter("ignore")

# download MiniImagenet
train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True)
train_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="train", download=False)
val_dataset = l2l.vision.datasets.MiniImagenet(root="./", mode="validation", download=False)


@dataclass
class ImageClassificationInputTransform(InputTransform):

image_size: Tuple[int, int] = (196, 196)
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

transform = {
"per_sample_transform": nn.Sequential(
ApplyToKeys(
def per_sample_transform(self):
return T.Compose(
[
ApplyToKeys(
DataKeys.INPUT,
T.Compose(
[
T.ToTensor(),
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
]
),
),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
]
)

def train_per_sample_transform(self):
return T.Compose(
[
ApplyToKeys(
DataKeys.INPUT,
T.Compose(
[
T.ToTensor(),
T.Resize(self.image_size),
T.Normalize(self.mean, self.std),
T.RandomHorizontalFlip(),
T.ColorJitter(),
T.RandomAutocontrast(),
T.RandomPerspective(),
]
),
),
ApplyToKeys("target", torch.as_tensor),
]
)

def per_batch_transform_on_device(self):
return ApplyToKeys(
DataKeys.INPUT,
nn.Sequential(
torchvision.transforms.ToTensor(),
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
),
),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
),
"collate": kornia_collate,
"per_batch_transform_on_device": ApplyToKeys(
DataKeys.INPUT,
Ka.RandomHorizontalFlip(p=0.25),
),
}
Ka.RandomHorizontalFlip(p=0.25),
)

def collate(self):
return kornia_collate


# construct datamodule

datamodule = ImageClassificationData.from_tensors(
train_data=train_dataset.x,
train_targets=torch.from_numpy(train_dataset.y.astype(int)),
val_data=val_dataset.x,
val_targets=torch.from_numpy(val_dataset.y.astype(int)),
transform=transform,
train_transform=ImageClassificationInputTransform,
val_transform=ImageClassificationInputTransform,
batch_size=1,
)

model = ImageClassifier(
backbone="resnet18",
training_strategy="prototypicalnetworks",
training_strategy_kwargs={
"epoch_length": 10 * 16,
"meta_batch_size": 4,
"meta_batch_size": 1,
"num_tasks": 200,
"test_num_tasks": 2000,
"ways": datamodule.num_classes,
@@ -92,9 +144,9 @@
)

trainer = flash.Trainer(
max_epochs=200,
gpus=2,
accelerator="ddp_shared",
max_epochs=1,
gpus=1,
precision=16,
)

trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
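
For reference, the manual step described in the module docstring (download the two MiniImagenet cache pickles from Dropbox, then rename them with cp) can also be scripted. The sketch below is not part of the committed file: the URLs are copied from the docstring, and the local filenames and working directory are assumptions chosen to match root="./" passed to l2l.vision.datasets.MiniImagenet above.

# Sketch of the manual download/rename step from the module docstring (assumes the
# Dropbox links are still live and the script is run from the directory used as root="./").
import urllib.request

cache_files = {
    "mini-imagenet-cache-train.pkl": "https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1",
    "mini-imagenet-cache-validation.pkl": "https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1",
}

for filename, url in cache_files.items():
    # Saving directly under the final name replaces the `cp` rename commands in the docstring.
    urllib.request.urlretrieve(url, filename)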
