Merge pull request #59 from discovery-unicamp/setr-logging-fix
Setr logging fix
GabrielBG0 authored May 14, 2024
2 parents cb7de07 + 9424c0b commit 94bf3f4
Showing 2 changed files with 40 additions and 117 deletions.
155 changes: 39 additions & 116 deletions minerva/models/nets/setr.py
@@ -1,10 +1,10 @@
 import warnings
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import lightning as L
 import torch
 from torch import nn
-from torchmetrics import JaccardIndex
+from torchmetrics import JaccardIndex, Metric
 
 from minerva.models.nets.vit import _VisionTransformerBackbone
 from minerva.utils.upsample import Upsample, resize
@@ -407,12 +407,9 @@ def __init__(
         conv_act: Optional[nn.Module] = None,
         interpolate_mode: str = "bilinear",
         loss_fn: Optional[nn.Module] = None,
-        log_train_metrics: bool = False,
-        train_metrics: Optional[nn.Module] = None,
-        log_val_metrics: bool = False,
-        val_metrics: Optional[nn.Module] = None,
-        log_test_metrics: bool = False,
-        test_metrics: Optional[nn.Module] = None,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
         aux_output: bool = True,
         aux_output_layers: list[int] | None = [9, 14, 19],
         aux_weights: list[float] = [0.3, 0.3, 0.3],
@@ -460,10 +457,18 @@ def __init__(
             The interpolation mode for upsampling in the decoder. Defaults to "bilinear".
         loss_fn : nn.Module, optional
             The loss function to be used during training. Defaults to None.
-        log_metrics : bool
-            Whether to log metrics during training. Defaults to True.
-        metrics : list[MetricTypeSetR], optional
-            The metrics to be used for evaluation. Defaults to [MetricTypeSetR.mIoU, MetricTypeSetR.mIoU, MetricTypeSetR.mIoU].
+        train_metrics : Dict[str, Metric], optional
+            The metrics to be used for training evaluation. Defaults to None.
+        val_metrics : Dict[str, Metric], optional
+            The metrics to be used for validation evaluation. Defaults to None.
+        test_metrics : Dict[str, Metric], optional
+            The metrics to be used for testing evaluation. Defaults to None.
+        aux_output : bool
+            Whether to include auxiliary output heads in the model. Defaults to True.
+        aux_output_layers : list[int] | None
+            The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19].
+        aux_weights : list[float]
+            The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3].
         """
         super().__init__()
@@ -486,27 +491,11 @@ def __init__(
         self.num_classes = num_classes
         self.aux_weights = aux_weights
 
-        self.log_train_metrics = log_train_metrics
-        self.log_val_metrics = log_val_metrics
-        self.log_test_metrics = log_test_metrics
-
-        if log_train_metrics:
-            assert (
-                train_metrics is not None
-            ), "train_metrics must be provided if log_train_metrics is True"
-            self.train_metrics = train_metrics
-
-        if log_val_metrics:
-            assert (
-                val_metrics is not None
-            ), "val_metrics must be provided if log_val_metrics is True"
-            self.val_metrics = val_metrics
-
-        if log_test_metrics:
-            assert (
-                test_metrics is not None
-            ), "test_metrics must be provided if log_test_metrics is True"
-            self.test_metrics = test_metrics
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
 
         self.model = _SetR_PUP(
             image_size=image_size,
@@ -531,18 +520,20 @@ def __init__(
             aux_output_layers=aux_output_layers,
         )
 
-        self.train_step_outputs = []
-        self.train_step_labels = []
-
-        self.val_step_outputs = []
-        self.val_step_labels = []
-
-        self.test_step_outputs = []
-        self.test_step_labels = []
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)
 
+    def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str):
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(
+                torch.argmax(y_hat, dim=1, keepdim=True), y
+            )
+            for metric_name, metric in self.metrics[step_name].items()
+        }
+
     def _loss_func(
         self,
         y_hat: (
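The added _compute_metrics method replaces the per-epoch JaccardIndex bookkeeping with a per-step dict comprehension over self.metrics[step_name]. A minimal sketch of that computation outside the class (not part of the commit; the JaccardIndex choice, class count, and tensor shapes are hypothetical):

    import torch
    from torchmetrics import JaccardIndex

    # One named metric per split, matching what the new constructor expects.
    metrics = {"mIoU": JaccardIndex(task="multiclass", num_classes=6)}

    y_hat = torch.randn(2, 6, 32, 32)        # model logits (N, C, H, W)
    y = torch.randint(0, 6, (2, 1, 32, 32))  # integer label mask

    # Mirrors the dict comprehension above: argmax over the class axis,
    # then evaluate each metric under a "<step>_<name>" key.
    scores = {
        f"val_{name}": metric(torch.argmax(y_hat, dim=1, keepdim=True), y)
        for name, metric in metrics.items()
    }
    # scores == {"val_mIoU": tensor(...)}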
@@ -577,6 +568,7 @@ def _loss_func(
                 + (loss_aux3 * self.aux_weights[2])
             )
         loss = self.loss_fn(y_hat, y.long())
+
         return loss
 
     def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
@@ -600,86 +592,17 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
         y_hat = self.model(x.float())
         loss = self._loss_func(y_hat[0], y.squeeze(1))
 
-        if step_name == "train":
-            self.train_step_outputs.append(y_hat[0])
-            self.train_step_labels.append(y)
-        elif step_name == "val":
-            self.val_step_outputs.append(y_hat[0])
-            self.val_step_labels.append(y)
-        elif step_name == "test":
-            self.test_step_outputs.append(y_hat[0])
-            self.test_step_labels.append(y)
-
-        self.log_dict(
-            {
-                f"{step_name}_loss": loss,
-            },
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-
-        return loss
-
-    def on_train_epoch_end(self):
-        if self.log_train_metrics:
-            y_hat = torch.cat(self.train_step_outputs)
-            y = torch.cat(self.train_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.train_metrics(preds, y)
-            mIoU = self.train_metrics.compute()
-
+        metrics = self._compute_metrics(y_hat[0], y, step_name)
+        for metric_name, metric_value in metrics.items():
             self.log_dict(
-                {
-                    f"train_metrics": mIoU,
-                },
+                {metric_name: metric_value},
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
             )
-            self.train_step_outputs.clear()
-            self.train_step_labels.clear()
 
-    def on_validation_epoch_end(self):
-        if self.log_val_metrics:
-            y_hat = torch.cat(self.val_step_outputs)
-            y = torch.cat(self.val_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.val_metrics(preds, y)
-            mIoU = self.val_metrics.compute()
-
-            self.log_dict(
-                {
-                    f"val_metrics": mIoU,
-                },
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-            self.val_step_outputs.clear()
-            self.val_step_labels.clear()
-
-    def on_test_epoch_end(self):
-        if self.log_test_metrics:
-            y_hat = torch.cat(self.test_step_outputs)
-            y = torch.cat(self.test_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.test_metrics(preds, y)
-            mIoU = self.test_metrics.compute()
-            self.log_dict(
-                {
-                    f"test_metrics": mIoU,
-                },
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-            self.test_step_outputs.clear()
-            self.test_step_labels.clear()
+        return loss
 
     def training_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, "train")
@@ -694,7 +617,7 @@ def predict_step(
         self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None
     ):
         x, _ = batch
-        return self.model(x)
+        return self.model(x)[0]
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=1e-3)
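Taken together, the constructor and _compute_metrics changes give the wrapper a dict-of-torchmetrics interface per split, logged from _single_step under train_<name>, val_<name>, and test_<name> keys. A hedged construction sketch (the wrapper class name SETR_PUP and all argument values are assumptions, not confirmed by this page; remaining constructor arguments are left at their defaults):

    from torchmetrics import JaccardIndex

    from minerva.models.nets.setr import SETR_PUP

    model = SETR_PUP(
        num_classes=6,
        train_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
        val_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
        test_metrics=None,  # None is allowed: _compute_metrics then returns {}
    )

With metrics supplied this way, the old log_*_metrics flags and the on_*_epoch_end hooks become unnecessary, which is exactly what this diff deletes.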
2 changes: 1 addition & 1 deletion tests/models/nets/test_setr.py
@@ -27,7 +27,7 @@ def test_setr_predict():
     preds = model.predict_step((x, mask), 0)
     assert preds is not None
     assert (
-        preds[0].shape == mask_shape
+        preds.shape == mask_shape
     ), f"Expected shape {mask_shape}, but got {preds[0].shape}"
 
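The test change follows from the new predict_step contract: the main head is returned directly, so the shape check moves from preds[0] to preds. A fragment continuing the sketch above (model as constructed there; input sizes hypothetical):

    import torch

    x = torch.rand(1, 3, 512, 512)
    mask = torch.zeros(1, 1, 512, 512, dtype=torch.long)

    preds = model.predict_step((x, mask), 0)  # a single tensor, not a tuple of heads
    assert isinstance(preds, torch.Tensor)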
