Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: Migrate vit gradcam code + pytorch export format for classificat… #7

Merged
merged 22 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ea26882
add: Migrate vit gradcam code + pytorch export format for classificat…
AlessandroPolidori Jun 1, 2023
e873068
fix: Fix docs classification tutorial
AlessandroPolidori Jun 1, 2023
4e6380a
fix: Fix sklearn classification docs tutorial
AlessandroPolidori Jun 1, 2023
4609b0d
fix: Set gradcam to false if imported model is torchscript
AlessandroPolidori Jun 1, 2023
5539075
fix: add try except to raise clear error in sklear-test pytorch model…
AlessandroPolidori Jun 1, 2023
ac9f38b
fix: set device=cpu in all classification tests
AlessandroPolidori Jun 1, 2023
f999007
fix: revert device change in classification tests
AlessandroPolidori Jun 1, 2023
f99ec92
fix: Classif. evaluation test now on cpu
AlessandroPolidori Jun 5, 2023
6de6a78
Merge remote-tracking branch 'origin/dev' into feature/gradcam_for_vits
AlessandroPolidori Jun 5, 2023
4c67015
fix: Cam objects now Optional and initialized as None
AlessandroPolidori Jun 5, 2023
db32198
fix: Moved common lines before if statement
AlessandroPolidori Jun 5, 2023
fc7d50f
fix: Pytorch model format has now same name with extension .pth
AlessandroPolidori Jun 5, 2023
b4b00a4
fix: Add empty line before return statements
AlessandroPolidori Jun 5, 2023
23ab687
fix: Add with torch.set_grad_enabled(self.gradcam) to evaluation test
AlessandroPolidori Jun 5, 2023
fe9f850
fix: Remove hardcoded input_shape
AlessandroPolidori Jun 5, 2023
2955b91
fix: Fix docstrings
AlessandroPolidori Jun 6, 2023
ab58ae1
fix: Fix classification documentation model.pth
AlessandroPolidori Jun 6, 2023
23ace40
fix: Narrow scikit classifiers' type in pytorch wrapper
AlessandroPolidori Jun 6, 2023
1b41803
add: Add citation in vit_explainability.py
AlessandroPolidori Jun 6, 2023
6da193d
fix: Fix model path in classification task test
AlessandroPolidori Jun 6, 2023
d12d9f4
add: Add check of classifier type in LinearModelPytorchWrapper
AlessandroPolidori Jun 6, 2023
2bd3f52
fix: Fix LinearModelPytorchWrapper example_input ValueError message
AlessandroPolidori Jun 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions docs/tutorials/examples/classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ print_config: true

model:
num_classes: ???
gradcam: true
module:
lr_scheduler_interval: "epoch"

task:
lr_multiplier: 0.0
gradcam: true
run_test: True
export_type: [torchscript]
export_type: [torchscript, pytorch]
report: True
output:
example: True
Expand Down Expand Up @@ -169,8 +169,9 @@ datamodule:
class_3: 2

task:
gradcam: True # Enable gradcam computation during evaluation
run_test: True # Perform test evaluation at the end of training
export_type: [torchscript]
export_type: [torchscript, pytorch]
report: True
output:
example: True # Generate an example of concordants and discordants predictions for each class
Expand All @@ -179,7 +180,6 @@ model:
num_classes: 3 # This is very important
module:
lr_scheduler_interval: "epoch"
gradcam: True # Enable gradcam computation during evaluation

backbone:
model:
Expand Down Expand Up @@ -258,8 +258,14 @@ core:
upload_artifacts: true
name: classification_evalutation_base

logger:
mlflow:
experiment_name: name_of_the_experiment
run_name: ${core.name}

task:
_target_: quadra.tasks.ClassificationEvaluation
gradcam: true
output:
example: true
model_path: ???
Expand All @@ -283,9 +289,6 @@ datamodule:
class_2: 1
class_3: 2

model:
num_classes: 3

core:
tag: "run"
upload_artifacts: true
Expand All @@ -294,7 +297,7 @@ core:
task:
output:
example: true
model_path: path/to/deployment_model
model_path: path/to/model.pth
```

Notice that we must provide the path to a deployment model file that will be used to perform inferences. In this case class_to_idx is mandatory (we can not infer it from a test-set). We suggest to be careful to set the same class_to_idx that has been used to train the model.
Expand Down
9 changes: 7 additions & 2 deletions docs/tutorials/examples/sklearn_classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ backbone:
pretrained: true
freeze: true

task:
export_type: [torchscript, pytorch]

core:
tag: "run"
name: "sklearn-classification"
Expand All @@ -112,6 +115,7 @@ datamodule:
```

By default the experiment will use dino_vitb8 as backbone, resizing the images to 224x224 and training a logistic regression classifier. Setting the `n_splits` parameter to 1 will use a standard 70/30 train/validation split (given the parameters specified in the base datamodule) instead of cross validation.
It will also export the model in two formats, "torchscript" and "pytorch".

An actual configuration file based on the above could be this one (suppose it's saved under `configs/experiment/custom_experiment/sklearn_classification.yaml`):
```yaml
Expand Down Expand Up @@ -139,7 +143,7 @@ datamodule:

task:
device: cuda:0
export_type: [torchscript]
export_type: [torchscript, pytorch]
output:
folder: classification_experiment
save_backbone: true
Expand Down Expand Up @@ -223,14 +227,15 @@ datamodule:

task:
device: cuda:0
gradcam: true
output:
folder: classification_test_experiment
report: true
example: true
experiment_path:
```

This will test the model trained in the given experiment on the given dataset. The experiment results will be saved under the `classification_test_experiment` folder.
This will test the model trained in the given experiment on the given dataset. The experiment results will be saved under the `classification_test_experiment` folder. If gradcam is set to True, original and gradcam results will be saved during the generate_report().

### Run

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ print_config: true

model:
num_classes: ???
gradcam: true
module:
lr_scheduler_interval: "epoch"

task:
lr_multiplier: 0.0
run_test: True
export_type: [torchscript]
export_type: [torchscript, pytorch]
report: True
gradcam: True
output:
example: True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ logger:
run_name: ${core.name}

task:
gradcam: true
output:
example: true
model_path: ???
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ backbone:
pretrained: true
freeze: true

task:
export_type: [torchscript, pytorch]

core:
tag: "run"
name: "sklearn-classification"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ core:
tag: run
name: sklearn-classification-test

task:
device: cuda:0
gradcam: false
output:
folder: classification_experiment
report: true
example: true
experiment_path: ???

datamodule:
num_workers: 8
batch_size: 32
Expand Down
3 changes: 1 addition & 2 deletions quadra/configs/model/classification.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ classifier:
module:
_target_: quadra.modules.classification.ClassificationModule
lr_scheduler_interval: "epoch"
criterion: ${loss}
gradcam: true
criterion: ${loss}
2 changes: 1 addition & 1 deletion quadra/configs/task/classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_target_: quadra.tasks.Classification
export_type: [torchscript]
export_type: [torchscript, pytorch]
lr_multiplier: null
output:
example: false
Expand Down
1 change: 1 addition & 0 deletions quadra/configs/task/classification_evaluation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ device: cuda:0
output:
example: true
model_path: ???
report: true
2 changes: 1 addition & 1 deletion quadra/configs/task/sklearn_classification.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_target_: quadra.tasks.SklearnClassification
device: "cuda:0"
export_type: [torchscript]
export_type: [torchscript, pytorch]
output:
folder: "classification_experiment"
save_backbone: false
Expand Down
3 changes: 2 additions & 1 deletion quadra/configs/task/sklearn_classification_test.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
_target_: quadra.tasks.classification.SklearnTestClassification
device: cuda:0
gradcam: false
output:
folder: classification_experiment
report: true
example: true
experiment_path:
experiment_path: ???
78 changes: 61 additions & 17 deletions quadra/modules/classification/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from typing import Any, List, Optional, Tuple, Union, cast

import numpy as np
import timm
import torch
import torchmetrics
import torchmetrics.functional as TMF
from pytorch_grad_cam import GradCAM
from scipy import ndimage
from torch import nn, optim

from quadra.models.classification import BaseNetworkBuilder
from quadra.modules.base import BaseLightningModule
from quadra.utils.models import is_vision_transformer
from quadra.utils.utils import get_logger
from quadra.utils.vit_explainability import VitAttentionGradRollout

log = get_logger(__name__)

Expand Down Expand Up @@ -45,11 +49,14 @@ def __init__(
self.train_acc = torchmetrics.Accuracy()
self.val_acc = torchmetrics.Accuracy()
self.test_acc = torchmetrics.Accuracy()
self.cam: GradCAM
self.cam: Optional[GradCAM] = None
self.grad_rollout: Optional[VitAttentionGradRollout] = None

if not isinstance(self.model.features_extractor, timm.models.resnet.ResNet):
if not isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and not is_vision_transformer(
cast(BaseNetworkBuilder, self.model).features_extractor
):
log.warning(
"Backbone must be compatible with gradcam, at the moment only ResNets supported, disabling gradcam"
"Backbone not compatible with gradcam. Only timm ResNets, timm ViTs and TorchHub dinoViTs supported",
)
self.gradcam = False

Expand Down Expand Up @@ -124,25 +131,55 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
prog_bar=False,
)

def on_predict_start(self) -> None:
"""If gradcam will be computed, saves all requires_grad values and set them to True before the predict."""
if self.gradcam:
def prepare_gradcam(self) -> None:
"""Instantiate gradcam handlers."""
if isinstance(self.model.features_extractor, timm.models.resnet.ResNet):
target_layers = [cast(BaseNetworkBuilder, self.model).features_extractor.layer4[-1]] # type: ignore[index]
self.cam = GradCAM(model=self.model, target_layers=target_layers, use_cuda=torch.cuda.is_available())
self.cam = GradCAM(
model=self.model,
target_layers=target_layers,
use_cuda=torch.cuda.is_available(),
)
# Activating gradients
for p in self.model.features_extractor.layer4[-1].parameters():
p.requires_grad = True
elif is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor):
self.grad_rollout = VitAttentionGradRollout(self.model)
AlessandroPolidori marked this conversation as resolved.
Show resolved Hide resolved
else:
log.warning("Gradcam not implemented for this backbone, it won't be computed")
self.original_requires_grads.clear()
self.gradcam = False

def on_predict_start(self) -> None:
"""If gradcam, prepares gradcam and saves params requires_grad state."""
if self.gradcam:
# Saving params requires_grad state
for p in self.model.parameters():
self.original_requires_grads.append(p.requires_grad)
p.requires_grad = True
self.prepare_gradcam()

return super().on_predict_start()

def on_predict_end(self) -> None:
"""If we computed gradcam, requires_grad values are reset to original value."""
if self.gradcam:
# Get back to initial state
for i, p in enumerate(self.model.parameters()):
p.requires_grad = self.original_requires_grads[i]

self.cam.activations_and_grads.release()
# We are using GradCAM package only for resnets at the moment
if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam is not None:
# Needed to solve jitting bug
self.cam.activations_and_grads.release()
elif (
is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor)
and self.grad_rollout is not None
):
for handle in self.grad_rollout.f_hook_handles:
handle.remove()
for handle in self.grad_rollout.b_hook_handles:
handle.remove()

return super().on_predict_end()

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
Expand All @@ -158,18 +195,25 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
grayscale_cam: gray scale gradcams
"""
im, _ = batch
# inference_mode set to false because gradcam needs gradients
outputs = self(im)
probs = torch.softmax(outputs, dim=1)
predicted_classes = torch.max(probs, dim=1).indices.tolist()
if self.gradcam:
# inference_mode set to false because gradcam needs gradients
with torch.inference_mode(False):
im = im.clone()
outputs = self(im)
probs = torch.softmax(outputs, dim=1)
predicted_classes = torch.max(probs, dim=1).indices.tolist()
grayscale_cam = self.cam(input_tensor=im, targets=None)

if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam:
grayscale_cam = self.cam(input_tensor=im, targets=None)
elif (
is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor) and self.grad_rollout
):
grayscale_cam_low_res = self.grad_rollout(input_tensor=im, targets_list=predicted_classes)
orig_shape = grayscale_cam_low_res.shape
new_shape = (orig_shape[0], im.shape[2], im.shape[3])
zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
else:
outputs = self(im)
probs = torch.softmax(outputs, dim=1)
predicted_classes = torch.max(probs, dim=1).indices.tolist()
grayscale_cam = None
return predicted_classes, grayscale_cam

Expand Down
2 changes: 1 addition & 1 deletion quadra/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Task(Generic[DataModuleT]):
"""Base Experiment Task.

Args:
config: The experiment configuration
config: The experiment configuration.
export_type: List of export method for the model, e.g. [torchscript]. Defaults to None.
"""

Expand Down
Loading