Skip to content

Commit

Permalink
dev: log confusion matrices at each epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesGaydon committed Feb 7, 2024
1 parent e1ccc7c commit 86fc605
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
15 changes: 14 additions & 1 deletion myria3d/callbacks/comet_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_comet_logger(trainer: Trainer) -> Optional[CometLogger]:
return logger

warnings.warn(
"You are using comet related callback, but CometLogger was not found for some reason...",
"You are using comet related functions, but trainer has no CometLogger among its loggers.",
UserWarning,
)
return None
Expand Down Expand Up @@ -71,3 +71,16 @@ def setup(self, trainer, pl_module, stage):
log_path = os.getcwd()
log.info(f"----------------\n LOGS DIR is {log_path}\n ----------------")
logger.experiment.log_parameter("experiment_logs_dirpath", log_path)


def log_comet_cm(lightning_module, confmat, phase):
logger = get_comet_logger(trainer=lightning_module)
if logger:
labels = list(lightning_module.hparams.classification_dict.values())
logger.experiment.log_confusion_matrix(
matrix=confmat.cpu().numpy().tolist(),
labels=labels,
file_name=f"{phase}-confusion-matrix",
title="{phase} confusion matrix",
epoch=lightning_module.current_epoch,
)
6 changes: 5 additions & 1 deletion myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_geometric.data import Batch
from torch_geometric.nn import knn_interpolate
from torchmetrics.classification import MulticlassJaccardIndex
from myria3d.callbacks.comet_callbacks import log_comet_cm

from myria3d.metrics.iou import iou
from myria3d.models.modules.pyg_randla_net import PyGRandLANet
Expand Down Expand Up @@ -139,7 +140,7 @@ def training_step(self, batch: Batch, batch_idx: int) -> dict:
self.criterion = self.criterion.to(logits.device)
loss = self.criterion(logits, targets)
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False)

with torch.no_grad():
preds = torch.argmax(logits.detach(), dim=1)
self.train_iou(preds, targets)
Expand All @@ -150,6 +151,7 @@ def on_train_epoch_end(self) -> None:
iou_epoch = self.train_iou.compute()
self.log("train/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.train_iou.confmat, "train")
log_comet_cm(self, self.train_iou.confmat, "train")
self.train_iou.reset()

def validation_step(self, batch: Batch, batch_idx: int) -> dict:
Expand Down Expand Up @@ -187,6 +189,7 @@ def on_validation_epoch_end(self) -> None:
iou_epoch = self.val_iou.compute()
self.log("val/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.val_iou.confmat, "val")
log_comet_cm(self, self.val_iou.confmat, "val")
self.val_iou.reset()

def test_step(self, batch: Batch, batch_idx: int):
Expand Down Expand Up @@ -221,6 +224,7 @@ def on_test_epoch_end(self) -> None:
iou_epoch = self.test_iou.compute()
self.log("test/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.test_iou.confmat, "test")
log_comet_cm(self, self.test_iou.confmat, "test")
self.test_iou.reset()

def predict_step(self, batch: Batch) -> dict:
Expand Down

0 comments on commit 86fc605

Please sign in to comment.