Setr logging fix #59

Merged: 3 commits, May 14, 2024
155 changes: 39 additions & 116 deletions minerva/models/nets/setr.py
@@ -1,10 +1,10 @@
import warnings
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import lightning as L
import torch
from torch import nn
from torchmetrics import JaccardIndex
from torchmetrics import JaccardIndex, Metric

from minerva.models.nets.vit import _VisionTransformerBackbone
from minerva.utils.upsample import Upsample, resize
@@ -407,12 +407,9 @@ def __init__(
conv_act: Optional[nn.Module] = None,
interpolate_mode: str = "bilinear",
loss_fn: Optional[nn.Module] = None,
log_train_metrics: bool = False,
train_metrics: Optional[nn.Module] = None,
log_val_metrics: bool = False,
val_metrics: Optional[nn.Module] = None,
log_test_metrics: bool = False,
test_metrics: Optional[nn.Module] = None,
train_metrics: Optional[Dict[str, Metric]] = None,
val_metrics: Optional[Dict[str, Metric]] = None,
test_metrics: Optional[Dict[str, Metric]] = None,
aux_output: bool = True,
aux_output_layers: list[int] | None = [9, 14, 19],
aux_weights: list[float] = [0.3, 0.3, 0.3],
@@ -460,10 +457,18 @@ def __init__(
The interpolation mode for upsampling in the decoder. Defaults to "bilinear".
loss_fn : nn.Module, optional
The loss function to be used during training. Defaults to None.
log_metrics : bool
Whether to log metrics during training. Defaults to True.
metrics : list[MetricTypeSetR], optional
The metrics to be used for evaluation. Defaults to [MetricTypeSetR.mIoU, MetricTypeSetR.mIoU, MetricTypeSetR.mIoU].
train_metrics : Dict[str, Metric], optional
The metrics to be used for training evaluation. Defaults to None.
val_metrics : Dict[str, Metric], optional
The metrics to be used for validation evaluation. Defaults to None.
test_metrics : Dict[str, Metric], optional
The metrics to be used for testing evaluation. Defaults to None.
aux_output : bool
Whether to include auxiliary output heads in the model. Defaults to True.
aux_output_layers : list[int] | None
The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19].
aux_weights : list[float]
The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3].

"""
super().__init__()
@@ -486,27 +491,11 @@ def __init__(
self.num_classes = num_classes
self.aux_weights = aux_weights

self.log_train_metrics = log_train_metrics
self.log_val_metrics = log_val_metrics
self.log_test_metrics = log_test_metrics

if log_train_metrics:
assert (
train_metrics is not None
), "train_metrics must be provided if log_train_metrics is True"
self.train_metrics = train_metrics

if log_val_metrics:
assert (
val_metrics is not None
), "val_metrics must be provided if log_val_metrics is True"
self.val_metrics = val_metrics

if log_test_metrics:
assert (
test_metrics is not None
), "test_metrics must be provided if log_test_metrics is True"
self.test_metrics = test_metrics
self.metrics = {
"train": train_metrics,
"val": val_metrics,
"test": test_metrics,
}

self.model = _SetR_PUP(
image_size=image_size,
@@ -531,18 +520,20 @@ def __init__(
aux_output_layers=aux_output_layers,
)

self.train_step_outputs = []
self.train_step_labels = []

self.val_step_outputs = []
self.val_step_labels = []

self.test_step_outputs = []
self.test_step_labels = []

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)

def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str):
if self.metrics[step_name] is None:
return {}

return {
f"{step_name}_{metric_name}": metric.to(self.device)(
torch.argmax(y_hat, dim=1, keepdim=True), y
)
for metric_name, metric in self.metrics[step_name].items()
}

def _loss_func(
self,
y_hat: (
@@ -577,6 +568,7 @@ def _loss_func(
+ (loss_aux3 * self.aux_weights[2])
)
loss = self.loss_fn(y_hat, y.long())

return loss

def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
@@ -600,86 +592,17 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
y_hat = self.model(x.float())
loss = self._loss_func(y_hat[0], y.squeeze(1))

if step_name == "train":
self.train_step_outputs.append(y_hat[0])
self.train_step_labels.append(y)
elif step_name == "val":
self.val_step_outputs.append(y_hat[0])
self.val_step_labels.append(y)
elif step_name == "test":
self.test_step_outputs.append(y_hat[0])
self.test_step_labels.append(y)

self.log_dict(
{
f"{step_name}_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)

return loss

def on_train_epoch_end(self):
if self.log_train_metrics:
y_hat = torch.cat(self.train_step_outputs)
y = torch.cat(self.train_step_labels)
preds = torch.argmax(y_hat, dim=1, keepdim=True)
self.train_metrics(preds, y)
mIoU = self.train_metrics.compute()

metrics = self._compute_metrics(y_hat[0], y, step_name)
for metric_name, metric_value in metrics.items():
self.log_dict(
{
f"train_metrics": mIoU,
},
{metric_name: metric_value},
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.train_step_outputs.clear()
self.train_step_labels.clear()

def on_validation_epoch_end(self):
if self.log_val_metrics:
y_hat = torch.cat(self.val_step_outputs)
y = torch.cat(self.val_step_labels)
preds = torch.argmax(y_hat, dim=1, keepdim=True)
self.val_metrics(preds, y)
mIoU = self.val_metrics.compute()

self.log_dict(
{
f"val_metrics": mIoU,
},
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.val_step_outputs.clear()
self.val_step_labels.clear()

def on_test_epoch_end(self):
if self.log_test_metrics:
y_hat = torch.cat(self.test_step_outputs)
y = torch.cat(self.test_step_labels)
preds = torch.argmax(y_hat, dim=1, keepdim=True)
self.test_metrics(preds, y)
mIoU = self.test_metrics.compute()
self.log_dict(
{
f"test_metrics": mIoU,
},
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.test_step_outputs.clear()
self.test_step_labels.clear()
return loss

def training_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, "train")
@@ -694,7 +617,7 @@ def predict_step(
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None
):
x, _ = batch
return self.model(x)
return self.model(x)[0]

def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=1e-3)
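
For context, here is a standalone sketch of the per-step metric pattern this PR introduces, in the spirit of the new _compute_metrics helper. The class count, tensor shapes, and the contents of the metrics dict below are illustrative assumptions, not part of the PR:

import torch
from torchmetrics import JaccardIndex

# Per-stage metric dicts, mirroring self.metrics = {"train": ..., "val": ..., "test": ...}
metrics = {
    "train": {"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
    "val": None,
    "test": None,
}

def compute_metrics(y_hat: torch.Tensor, y: torch.Tensor, step_name: str) -> dict:
    # Stages without metrics return an empty dict, so nothing is logged for them.
    if metrics[step_name] is None:
        return {}
    return {
        f"{step_name}_{name}": metric(torch.argmax(y_hat, dim=1, keepdim=True), y)
        for name, metric in metrics[step_name].items()
    }

# Hypothetical batch: logits for 2 images, 6 classes, 16x16 pixels; integer masks.
y_hat = torch.randn(2, 6, 16, 16)
y = torch.randint(0, 6, (2, 1, 16, 16))
print(compute_metrics(y_hat, y, "train"))  # -> {"train_mIoU": tensor(...)}
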
2 changes: 1 addition & 1 deletion tests/models/nets/test_setr.py
@@ -27,7 +27,7 @@ def test_setr_predict():
preds = model.predict_step((x, mask), 0)
assert preds is not None
assert (
preds[0].shape == mask_shape
preds.shape == mask_shape
), f"Expected shape {mask_shape}, but got {preds.shape}"
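
As a side note on the predict_step change above: with auxiliary heads enabled, the model's forward returns the main prediction followed by the auxiliary ones, and predict_step now unpacks only the first element, which is why the test compares preds.shape directly against the mask shape. A minimal sketch of that behavior; the stand-in model and all shapes are assumptions:

import torch

def fake_setr_forward(x: torch.Tensor):
    # Stand-in for the model forward with aux_output=True: main head plus
    # three auxiliary heads (one per entry in aux_output_layers=[9, 14, 19]).
    main = torch.randn(x.shape[0], 6, x.shape[2], x.shape[3])
    aux = tuple(torch.randn_like(main) for _ in range(3))
    return (main, *aux)

x = torch.randn(2, 3, 64, 64)
preds = fake_setr_forward(x)[0]  # what predict_step returns after this PR
assert preds.shape == (2, 6, 64, 64)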

