Refactor/evaluation task #22

Merged: 19 commits, Jul 18, 2023
Changes shown from 12 of the 19 commits

Commits (19)
e115fd3  refactor: Anomaly and Segmentation evaluations now extend Evaluation … (AlessandroPolidori, Jul 5, 2023)
7a16e87  Add TODO: segmentation evaluation still not tested (AlessandroPolidori, Jul 5, 2023)
8d1b68b  refactor: Now classification evaluation extends Evaluation (AlessandroPolidori, Jul 5, 2023)
3b6314f  refactor: Now all evaluation tasks extend Evaluation (AlessandroPolidori, Jul 6, 2023)
4c5c0b7  fix: Call datamodule.prepare_data() in segm. evaluation, add seg eval… (AlessandroPolidori, Jul 6, 2023)
886a170  fix: Add back trainer to smp_multiclass experiment (AlessandroPolidori, Jul 6, 2023)
6d613ea  fix: Delete trainer in smp_multiclass_eval experiment config (AlessandroPolidori, Jul 7, 2023)
6824453  style: Add type hint (AlessandroPolidori, Jul 7, 2023)
930ccde  fix: Device arg is optional (AlessandroPolidori, Jul 7, 2023)
b730dcb  todo: Add todo for the prepare_data() manual call (AlessandroPolidori, Jul 7, 2023)
be35de4  style: Delete Todos (AlessandroPolidori, Jul 7, 2023)
99be535  style: Move self.deployment_model = self.model_path to parent class (AlessandroPolidori, Jul 7, 2023)
57db12e  fix: Model import in sklearn evaluations now supports .pt and .pth mo… (AlessandroPolidori, Jul 10, 2023)
ef8e089  fix: Swap to correct order input height and width read from model json (AlessandroPolidori, Jul 11, 2023)
66054aa  style: Move line of code (AlessandroPolidori, Jul 11, 2023)
4658785  add: Propagating mean and std in segmentation prepare (AlessandroPolidori, Jul 11, 2023)
dec357f  add: small changes (AlessandroPolidori, Jul 11, 2023)
083dbca  fix: Add **kwargs to sklearnTestClassification (AlessandroPolidori, Jul 11, 2023)
94e096f  fix: Fix config and docs for sklearn test classif and patch (AlessandroPolidori, Jul 11, 2023)
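
Taken together, these commits consolidate the duplicated evaluation logic into a shared base class. A rough outline of the resulting hierarchy, inferred from the class names in the diffs below (placeholder bodies only, not the real implementation):

```python
# Outline only: the real classes live in quadra.tasks (base.py, anomaly.py, etc.).

class Task:                              # existing base task (generic over its datamodule in the real code)
    ...

class Evaluation(Task):                  # new shared base: loads model.json, imports the deployment
    ...                                  # model, and aligns transform input sizes with the model

class AnomalibEvaluation(Evaluation):    # previously extended Task directly
    ...

# SegmentationAnalysisEvaluation and the sklearn classification test tasks follow the same pattern.
```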
@@ -16,6 +16,9 @@ core:
  tag: "run"
  name: "sklearn-classification-patch-test"

task:
  model_path: ???

datamodule:
  num_workers: 8
  batch_size: 32
@@ -23,8 +23,8 @@ task:
folder: classification_experiment
report: true
example: true
experiment_path: ???
model_path: ???

datamodule:
num_workers: 8
batch_size: 32
17 changes: 17 additions & 0 deletions quadra/configs/experiment/base/segmentation/smp_evaluation.yaml
@@ -0,0 +1,17 @@
# @package _global_

defaults:
  - override /datamodule: base/segmentation
  - override /transforms: default_resize
  - override /task: segmentation_evaluation
core:
  tag: "run"
  name: "quadra_default"
  upload_artifacts: True

task:
  model_path: ???

datamodule:
  num_workers: 5
  batch_size: 32
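
As a side note, the `???` values in these configs are OmegaConf's marker for a mandatory value: the experiment fails fast unless `task.model_path` is supplied at launch. A minimal, self-contained illustration of that behaviour (the path used below is purely hypothetical):

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

# '???' marks a mandatory value: reading it before it is set raises an error.
cfg = OmegaConf.create({"task": {"model_path": "???"}})
try:
    _ = cfg.task.model_path
except MissingMandatoryValue:
    cfg.task.model_path = "/path/to/deployment_model.pt"  # hypothetical path, normally passed as an override
```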
10 changes: 5 additions & 5 deletions quadra/configs/experiment/base/segmentation/smp_multiclass.yaml
@@ -11,16 +11,16 @@ defaults:
  - override /backbone: smp
  - override /trainer: lightning_gpu

core:
  tag: "run"
  name: "quadra_default"
  upload_artifacts: True

trainer:
  devices: [0]
  max_epochs: 100
  num_sanity_val_steps: 0

core:
  tag: "run"
  name: "quadra_default"
  upload_artifacts: True

datamodule:
  num_workers: 8
  batch_size: 32
@@ -0,0 +1,18 @@
# @package _global_

defaults:
  - override /datamodule: base/segmentation_multiclass
  - override /transforms: default_resize
  - override /task: segmentation_evaluation

core:
  tag: "run"
  name: "quadra_default"
  upload_artifacts: True

task:
  model_path: ???

datamodule:
  num_workers: 8
  batch_size: 32
3 changes: 3 additions & 0 deletions quadra/configs/task/segmentation_evaluation.yaml
@@ -0,0 +1,3 @@
_target_: quadra.tasks.SegmentationAnalysisEvaluation
device: cuda:0
model_path: ???
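
For orientation, a task config like the one above is what ends up being handed to Hydra's `instantiate`. Something along these lines is the idea, sketched under the assumption that the surrounding experiment config is already composed and with a hypothetical model path:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Sketch only: quadra's real entry point composes the full experiment config first.
task_cfg = OmegaConf.create(
    {
        "_target_": "quadra.tasks.SegmentationAnalysisEvaluation",
        "device": "cuda:0",
        "model_path": "/path/to/model.pt",  # hypothetical; '???' must be overridden at launch
    }
)
# task = instantiate(task_cfg, config=experiment_cfg)  # experiment_cfg assumed to be available
# task.execute()
```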
2 changes: 1 addition & 1 deletion quadra/configs/task/sklearn_classification_patch_test.yaml
@@ -5,4 +5,4 @@ output:
report: true
example: true
reconstruction_method: major_voting
experiment_path:
model_path: ???
2 changes: 1 addition & 1 deletion quadra/configs/task/sklearn_classification_test.yaml
@@ -5,4 +5,4 @@ output:
folder: classification_experiment
report: true
example: true
experiment_path: ???
model_path: ???
69 changes: 12 additions & 57 deletions quadra/tasks/anomaly.py
@@ -2,8 +2,7 @@
import json
import os
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union, cast
from typing import Dict, Generic, List, Optional, TypeVar, Union, cast

import cv2
import hydra
@@ -21,10 +20,10 @@

from quadra.callbacks.mlflow import get_mlflow_logger
from quadra.datamodules import AnomalyDataModule
from quadra.tasks.base import LightningTask, Task
from quadra.tasks.base import Evaluation, LightningTask
from quadra.utils import utils
from quadra.utils.classification import get_results
from quadra.utils.export import export_torchscript_model, import_deployment_model
from quadra.utils.export import export_torchscript_model

log = utils.get_logger(__name__)

@@ -287,7 +286,7 @@ def _upload_artifacts(self):
utils.upload_file_tensorboard(a, tensorboard_logger)


class AnomalibEvaluation(Task[AnomalyDataModule]):
class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
"""Evaluation task for Anomalib.

Args:
@@ -301,63 +300,19 @@ class AnomalibEvaluation(Task[AnomalyDataModule]):
def __init__(
self, config: DictConfig, model_path: str, use_training_threshold: bool = False, device: Optional[str] = None
):
super().__init__(config=config)
self.model_data: Dict[str, Any]
self._deployment_model: Any
self.deployment_model_type: str
self.model_path = model_path
self.output_path = ""

if device is None:
self.device = utils.get_device()
else:
self.device = device
super().__init__(config=config, model_path=model_path, device=device)

self.config = config
self.report_path = ""
self.metadata = {"report_files": []}
self.model_info_filename = "model.json"
self.use_training_threshold = use_training_threshold

@property
def deployment_model(self):
"""Deployment model."""
return self._deployment_model

@deployment_model.setter
def deployment_model(self, model_path: str):
"""Set the deployment model."""
self._deployment_model, self.deployment_model_type = import_deployment_model(model_path, self.device)

def prepare(self) -> None:
"""Prepare the evaluation."""
with open(os.path.join(Path(self.model_path).parent, self.model_info_filename)) as f:
self.model_data = json.load(f)

if not isinstance(self.model_data, dict):
raise ValueError("Model info file is not a valid json")

if self.model_data["input_size"][0] != self.config.transforms.input_height:
log.warning(
f"Input height of the model ({self.model_data['input_size'][0]}) is different from the one specified "
+ f"in the config ({self.config.transforms.input_height}). Fixing the config."
)
self.config.transforms.input_height = self.model_data["input_size"][0]

if self.model_data["input_size"][1] != self.config.transforms.input_width:
log.warning(
f"Input width of the model ({self.model_data['input_size'][1]}) is different from the one specified "
+ f"in the config ({self.config.transforms.input_width}). Fixing the config."
)
self.config.transforms.input_width = self.model_data["input_size"][1]

self.deployment_model = self.model_path

super().prepare()
self.datamodule = self.config.datamodule

def test(self) -> None:
"""Perform test."""
log.info("Running test")
# prepare_data() must be explicitly called because there is no lightning training
self.datamodule.prepare_data()
self.datamodule.setup(stage="test")
test_dataloader = self.datamodule.test_dataloader()
@@ -422,8 +377,8 @@ def generate_report(self) -> None:
if len(self.report_path) > 0:
os.makedirs(self.report_path, exist_ok=True)

os.makedirs(os.path.join(self.output_path, "predictions"), exist_ok=True)
os.makedirs(os.path.join(self.output_path, "heatmaps"), exist_ok=True)
os.makedirs(os.path.join(self.report_path, "predictions"), exist_ok=True)
os.makedirs(os.path.join(self.report_path, "heatmaps"), exist_ok=True)

anomaly_scores = self.metadata["anomaly_scores"].cpu().numpy()
good_scores = anomaly_scores[np.where(np.array(self.metadata["image_labels"]) == 0)]
@@ -517,7 +472,7 @@ def generate_report(self) -> None:

output_mask = output_mask * 255
output_mask = cv2.resize(output_mask, (img.shape[1], img.shape[0]))
cv2.imwrite(os.path.join(self.output_path, "predictions", output_mask_name), output_mask)
cv2.imwrite(os.path.join(self.report_path, "predictions", output_mask_name), output_mask)

# Normalize the heatmaps based on the current min and max anomaly score, otherwise even on good images the
# anomaly map looks like there are defects while it's not true
@@ -529,7 +484,7 @@ def generate_report(self) -> None:
output_heatmap = anomaly_map_to_color_map(output_heatmap, normalize=False)
output_heatmap = cv2.resize(output_heatmap, (img.shape[1], img.shape[0]))
cv2.imwrite(
os.path.join(self.output_path, "heatmaps", output_mask_name),
os.path.join(self.report_path, "heatmaps", output_mask_name),
cv2.cvtColor(output_heatmap, cv2.COLOR_RGB2BGR),
)

@@ -541,7 +496,7 @@ def generate_report(self) -> None:
],
}

with open(os.path.join(self.output_path, "anomaly_test_output.json"), "w") as f:
with open(os.path.join(self.report_path, "anomaly_test_output.json"), "w") as f:
json.dump(json_output, f)

def execute(self) -> None:
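
The net effect of the anomaly.py changes above: AnomalibEvaluation no longer re-implements model loading, device selection, or the model.json checks, since those now live in the shared Evaluation base shown in the base.py diff below. Roughly, a concrete evaluation after this refactor looks like the following sketch (illustrative only; the real subclasses differ in detail):

```python
from quadra.tasks.base import Evaluation


class MyEvaluation(Evaluation):
    """Illustrative subclass only, not part of the PR."""

    def prepare(self) -> None:
        super().prepare()  # base class: read model.json, align transform sizes, import the deployment model
        self.datamodule = self.config.datamodule

    def test(self) -> None:
        # prepare_data() must be called manually because there is no Lightning Trainer in evaluation tasks
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="test")
        # ... run self.deployment_model over self.datamodule.test_dataloader() and collect metrics ...
```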
61 changes: 44 additions & 17 deletions quadra/tasks/base.py
@@ -1,4 +1,6 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

import hydra
@@ -9,6 +11,7 @@
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities.device_parser import parse_gpu_ids
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.jit._script import RecursiveScriptModule

from quadra import get_version
@@ -290,46 +293,70 @@ def execute(self) -> None:
log.info("If you are reading this, it means that library is installed correctly!")


class Evaluation(Task):
class Evaluation(Generic[DataModuleT], Task[DataModuleT]):
"""Base Evaluation Task with deployment models.

Args:
config: The experiment configuration
model_path: The model path.
report_folder: The report folder. Defaults to None.
device: Device to use for evaluation. If None, the device is automatically determined.

Raises:
ValueError: If the experiment path is not provided
"""

def __init__(
self,
config: DictConfig,
model_path: str,
report_folder: Optional[str] = None,
device: Optional[str] = None,
):
super().__init__(config=config)

if device is None:
self.device = utils.get_device()
else:
self.device = device

self.config = config
self.metadata = {"report_files": []}
self.model_data: Dict[str, Any]
self.model_path = model_path
self.device = utils.get_device()
self.report_folder = report_folder
self._deployment_model: RecursiveScriptModule
self.deployment_model_type: str
if self.report_folder is None:
log.warning("Report folder is not provided, using default report folder")
self.report_folder = "report"
self.model_info_filename = "model.json"
self.report_path = ""
self.metadata = {"report_files": []}

@property
def deployment_model(self) -> RecursiveScriptModule:
"""RecursiveScriptModule: The deployment model."""
def deployment_model(self) -> Union[RecursiveScriptModule, nn.Module]:
"""Deployment model."""
return self._deployment_model

@deployment_model.setter
def deployment_model(self, model: RecursiveScriptModule) -> None:
"""RecursiveScriptModule: The deployment model."""
self._deployment_model = model
def deployment_model(self, model_path: str):
"""Set the deployment model."""
self._deployment_model, self.deployment_model_type = import_deployment_model( # type: ignore[assignment]
model_path, self.device
)

def prepare(self) -> None:
"""Prepare the evaluation."""
self.deployment_model, self.deployment_model_type = import_deployment_model(self.model_path, self.device)
with open(os.path.join(Path(self.model_path).parent, self.model_info_filename)) as f:
self.model_data = json.load(f)

if not isinstance(self.model_data, dict):
raise ValueError("Model info file is not a valid json")

if self.model_data["input_size"][0] != self.config.transforms.input_height:
log.warning(
f"Input height of the model ({self.model_data['input_size'][0]}) is different from the one specified "
+ f"in the config ({self.config.transforms.input_height}). Fixing the config."
)
self.config.transforms.input_height = self.model_data["input_size"][0]

if self.model_data["input_size"][1] != self.config.transforms.input_width:
log.warning(
f"Input width of the model ({self.model_data['input_size'][1]}) is different from the one specified "
+ f"in the config ({self.config.transforms.input_width}). Fixing the config."
)
self.config.transforms.input_width = self.model_data["input_size"][1]

self.deployment_model = self.model_path # type: ignore[assignment]
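
For context, the prepare() above expects a model info file named model.json next to the exported model; only the "input_size" field ([height, width], per the "Swap to correct order" commit) is read here. A minimal sketch of such a file, with made-up values (files written by quadra's export step may carry more keys):

```python
import json

# Hypothetical example of the model info file read by Evaluation.prepare().
model_info = {"input_size": [224, 224]}  # [height, width]; values here are made up

with open("model.json", "w") as f:
    json.dump(model_info, f)
```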