diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3421648b85c..826f1824104 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -176,10 +176,12 @@ class and used with 'ce' loss num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, mdmc_average="global", + average="micro", ), MulticlassJaccardIndex( num_classes=self.hyperparams["num_classes"], ignore_index=self.ignore_index, + average="micro", ), ], prefix="train_",