logging fix ready
GabrielBG0 committed May 14, 2024
1 parent 6d7ec8c commit b0b3b2b
Showing 1 changed file with 31 additions and 37 deletions.
68 changes: 31 additions & 37 deletions minerva/models/nets/setr.py
@@ -1,5 +1,5 @@
import warnings
-from typing import Optional, Tuple, Dict
+from typing import Dict, Optional, Tuple

import lightning as L
import torch
@@ -407,9 +407,9 @@ def __init__(
conv_act: Optional[nn.Module] = None,
interpolate_mode: str = "bilinear",
loss_fn: Optional[nn.Module] = None,
-        train_metrics: Optional[Metric] = None,
-        val_metrics: Optional[Metric] = None,
-        test_metrics: Optional[Metric] = None,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
aux_output: bool = True,
aux_output_layers: list[int] | None = [9, 14, 19],
aux_weights: list[float] = [0.3, 0.3, 0.3],
@@ -483,17 +483,11 @@ def __init__(
self.num_classes = num_classes
self.aux_weights = aux_weights

-        if train_metrics is not None:
-            self.train_metrics = train_metrics
-            self.log_train_metrics = True
-
-        if val_metrics is not None:
-            self.val_metrics = val_metrics
-            self.log_val_metrics = True
-
-        if test_metrics is not None:
-            self.test_metrics = test_metrics
-            self.log_test_metrics = True
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }

self.model = _SetR_PUP(
image_size=image_size,
@@ -520,8 +514,17 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)

-    def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor) ->
+    def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str):
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(
+                torch.argmax(y_hat, dim=1, keepdim=True), y
+            )
+            for metric_name, metric in self.metrics[step_name].items()
+        }

def _loss_func(
self,
@@ -557,6 +560,7 @@ def _loss_func
+ (loss_aux3 * self.aux_weights[2])
)
loss = self.loss_fn(y_hat, y.long())

return loss

def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
@@ -580,25 +584,15 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
y_hat = self.model(x.float())
loss = self._loss_func(y_hat[0], y.squeeze(1))

if step_name == "train":
self.train_step_outputs.append(y_hat[0])
self.train_step_labels.append(y)
elif step_name == "val":
self.val_step_outputs.append(y_hat[0])
self.val_step_labels.append(y)
elif step_name == "test":
self.test_step_outputs.append(y_hat[0])
self.test_step_labels.append(y)

self.log_dict(
{
f"{step_name}_loss": loss,
},
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
metrics = self._compute_metrics(y_hat[0], y, step_name)
for metric_name, metric_value in metrics.items():
self.log_dict(
{metric_name: metric_value},
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)

return loss

@@ -615,7 +609,7 @@ def predict_step(
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None
):
x, _ = batch
-        return self.model(x)
+        return self.model(x)[0]

def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=1e-3)
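For context, here is a minimal usage sketch of the refactored metrics handling; it is not part of the commit. It assumes the public LightningModule defined in minerva/models/nets/setr.py is importable as SETR_PUP, that torchmetrics is installed, and that num_classes is accepted by the constructor; adjust names to the actual Minerva API.

```python
# Minimal usage sketch (not part of this commit). SETR_PUP, its import path,
# and the constructor arguments shown here are assumptions based on the
# signature visible in the diff; `torchmetrics` is assumed to be installed.
from torchmetrics import JaccardIndex

from minerva.models.nets.setr import SETR_PUP

num_classes = 6  # hypothetical number of segmentation classes

# After this change, per-stage metrics are plain dicts keyed by metric name.
model = SETR_PUP(
    num_classes=num_classes,
    train_metrics={"iou": JaccardIndex(task="multiclass", num_classes=num_classes)},
    val_metrics={"iou": JaccardIndex(task="multiclass", num_classes=num_classes)},
    test_metrics={"iou": JaccardIndex(task="multiclass", num_classes=num_classes)},
)

# During fit/validate/test, _single_step calls _compute_metrics(y_hat, y, step_name),
# which returns {f"{step_name}_{metric_name}": value, ...}; each entry is logged
# per epoch (on_step=False, on_epoch=True), e.g. as "train_iou" or "val_iou".
```

Because the metrics are kept in a plain dict rather than a module container, `_compute_metrics` moves each metric to the model's device explicitly via `metric.to(self.device)` before calling it.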
