From 913c1e36f7ef31fb37a46a483132c0534bb4c3c6 Mon Sep 17 00:00:00 2001 From: Joshua Blackburn Date: Mon, 27 Jul 2020 08:28:50 -0400 Subject: [PATCH] Add argument comments --- pytorch_lightning/metrics/classification.py | 1 + pytorch_lightning/metrics/functional/classification.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index 01b347f2292ab..3cd802b97ec53 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -103,6 +103,7 @@ def __init__( normalize: whether to compute a normalized confusion matrix reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + num_classes: number of classes if known. Important for DDP reduction. """ super().__init__(name='confusion_matrix', reduce_group=reduce_group, diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 8ce7b08340636..d3d67ef9fc876 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -245,6 +245,7 @@ def confusion_matrix( pred: estimated targets target: ground truth labels normalize: normalizes confusion matrix + num_classes: number of classes if known Return: Tensor, confusion matrix C [num_classes, num_classes ]