From 70668856ecfbdfaf98cdc8d2f303e1763a293260 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 12 May 2023 12:28:17 -0500 Subject: [PATCH 1/2] SemanticSegmentationTask: fix ignore_index weighting --- torchgeo/trainers/segmentation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3421648b85c..acc58fd7771 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -176,10 +176,12 @@ class and used with 'ce' loss num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, mdmc_average="global", + average="weighted", ), MulticlassJaccardIndex( num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, + average="weighted", ), ], prefix="train_", From cf658bcf09b8429a7cbdee807bb831454a1aaa05 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 12 May 2023 13:41:03 -0500 Subject: [PATCH 2/2] weighted -> micro --- torchgeo/trainers/segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index acc58fd7771..826f1824104 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -176,12 +176,12 @@ class and used with 'ce' loss num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, mdmc_average="global", - average="weighted", + average="micro", ), MulticlassJaccardIndex( num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, - average="weighted", + average="micro", ), ], prefix="train_",