diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index a189992227295..1ed645f86909c 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -30,9 +30,9 @@ precision, precision_recall_curve, recall, - roc + roc, ) -from pytorch_lightning.metrics.metric import TensorCollectionMetric, TensorMetric +from pytorch_lightning.metrics.metric import TensorMetric class Accuracy(TensorMetric): @@ -44,17 +44,16 @@ class Accuracy(TensorMetric): >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = Accuracy() - >>> metric(pred, target) - tensor(0.7500) + >>> metric(pred, target).item() + 0.75 """ def __init__( - self, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", + reduce_group: Any = None, ): """ Args: @@ -65,11 +64,8 @@ def __init__( - none: pass array - sum: add elements reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='accuracy', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__(name="accuracy", reduce_group=reduce_group) self.num_classes = num_classes self.reduction = reduction @@ -84,8 +80,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. """ - return accuracy(pred=pred, target=target, - num_classes=self.num_classes, reduction=self.reduction) + return accuracy(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) class ConfusionMatrix(TensorMetric): @@ -106,22 +101,21 @@ class ConfusionMatrix(TensorMetric): """ def __init__( - self, - num_classes: Optional[int] = None, - normalize: bool = False, - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + normalize: bool = False, + reduce_group: Any = None, ): """ Args: num_classes: number of classes normalize: whether to compute a normalized confusion matrix reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='confusion_matrix', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="confusion_matrix", + reduce_group=reduce_group, + ) self.normalize = normalize self.num_classes = num_classes @@ -140,8 +134,16 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: normalize=self.normalize, num_classes=self.num_classes) + def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: + """Aggregates results by stacking them instead of concatenating before averaging. 
+ + Returns: + the aggregated results + """ + return torch.stack(tensors).mean(0) + -class PrecisionRecallCurve(TensorCollectionMetric): +class PrecisionRecallCurve(TensorMetric): """ Computes the precision recall curve @@ -161,28 +163,27 @@ class PrecisionRecallCurve(TensorCollectionMetric): """ def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - reduce_op: Any = None, + self, + pos_label: int = 1, + reduce_group: Any = None, ): """ Args: pos_label: positive label indicator reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='precision_recall_curve', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="precision_recall_curve", + reduce_group=reduce_group, + ) self.pos_label = pos_label def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, + self, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Actual metric computation @@ -197,9 +198,7 @@ def forward( - recall values - threshold values """ - return precision_recall_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=self.pos_label) + return precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) class Precision(TensorMetric): @@ -217,11 +216,10 @@ class Precision(TensorMetric): """ def __init__( - self, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", + reduce_group: Any = None, ): """ Args: @@ -232,11 +230,11 @@ def __init__( - none: pass array - sum: add elements reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='precision', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="precision", + reduce_group=reduce_group, + ) self.num_classes = num_classes self.reduction = reduction @@ -251,9 +249,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. """ - return precision(pred=pred, target=target, - num_classes=self.num_classes, - reduction=self.reduction) + return precision(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) class Recall(TensorMetric): @@ -271,11 +267,10 @@ class Recall(TensorMetric): """ def __init__( - self, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", + reduce_group: Any = None, ): """ Args: @@ -286,11 +281,11 @@ def __init__( - none: pass array - sum: add elements reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='recall', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="recall", + reduce_group=reduce_group, + ) self.num_classes = num_classes self.reduction = reduction @@ -306,10 +301,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. 
""" - return recall(pred=pred, - target=target, - num_classes=self.num_classes, - reduction=self.reduction) + return recall(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) class AveragePrecision(TensorMetric): @@ -327,28 +319,24 @@ class AveragePrecision(TensorMetric): """ def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - reduce_op: Any = None, + self, + pos_label: int = 1, + reduce_group: Any = None, ): """ Args: pos_label: positive label indicator reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='AP', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="AP", + reduce_group=reduce_group, + ) self.pos_label = pos_label def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None + self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None ) -> torch.Tensor: """ Actual metric computation @@ -361,9 +349,7 @@ def forward( Return: torch.Tensor: classification score """ - return average_precision(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=self.pos_label) + return average_precision(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) class AUROC(TensorMetric): @@ -381,28 +367,24 @@ class AUROC(TensorMetric): """ def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - reduce_op: Any = None, + self, + pos_label: int = 1, + reduce_group: Any = None, ): """ Args: pos_label: positive label indicator reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='auroc', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="auroc", + reduce_group=reduce_group, + ) self.pos_label = pos_label def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None + self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None ) -> torch.Tensor: """ Actual metric computation @@ -415,9 +397,7 @@ def forward( Return: torch.Tensor: classification score """ - return auroc(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=self.pos_label) + return auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) class FBeta(TensorMetric): @@ -435,12 +415,11 @@ class FBeta(TensorMetric): """ def __init__( - self, - beta: float, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', - reduce_group: Any = None, - reduce_op: Any = None, + self, + beta: float, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", + reduce_group: Any = None, ): """ Args: @@ -452,11 +431,11 @@ def __init__( - none: pass array - sum: add elements reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for DDP reduction """ - super().__init__(name='fbeta', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="fbeta", + reduce_group=reduce_group, + ) self.beta = beta self.num_classes = num_classes @@ -473,9 +452,9 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: torch.Tensor: classification score """ - return fbeta_score(pred=pred, target=target, - beta=self.beta, num_classes=self.num_classes, - reduction=self.reduction) + return 
fbeta_score( + pred=pred, target=target, beta=self.beta, num_classes=self.num_classes, reduction=self.reduction + ) class F1(TensorMetric): @@ -493,11 +472,10 @@ class F1(TensorMetric): """ def __init__( - self, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", + reduce_group: Any = None, ): """ Args: @@ -508,11 +486,11 @@ def __init__( - none: pass array - sum: add elements reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='f1', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="f1", + reduce_group=reduce_group, + ) self.num_classes = num_classes self.reduction = reduction @@ -528,12 +506,10 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: torch.Tensor: classification score """ - return f1_score(pred=pred, target=target, - num_classes=self.num_classes, - reduction=self.reduction) + return f1_score(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) -class ROC(TensorCollectionMetric): +class ROC(TensorMetric): """ Computes the Receiver Operator Characteristic (ROC) @@ -553,28 +529,24 @@ class ROC(TensorCollectionMetric): """ def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - reduce_op: Any = None, + self, + pos_label: int = 1, + reduce_group: Any = None, ): """ Args: pos_label: positive label indicator reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='roc', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="roc", + reduce_group=reduce_group, + ) self.pos_label = pos_label def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None + self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Actual metric computation @@ -589,12 +561,10 @@ def forward( - true positive rate - thresholds """ - return roc(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=self.pos_label) + return roc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) -class MulticlassROC(TensorCollectionMetric): +class MulticlassROC(TensorMetric): """ Computes the multiclass ROC @@ -615,27 +585,27 @@ class MulticlassROC(TensorCollectionMetric): """ def __init__( - self, - num_classes: Optional[int] = None, - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + reduce_group: Any = None, ): """ Args: num_classes: number of classes reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='multiclass_roc', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="multiclass_roc", + reduce_group=reduce_group, + ) self.num_classes = num_classes def forward( - self, pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, + self, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Actual metric computation @@ -649,13 +619,19 @@ def forward( tuple: A tuple 
consisting of one tuple per class, holding false positive rate, true positive rate and thresholds """ - return multiclass_roc(pred=pred, - target=target, - sample_weight=sample_weight, - num_classes=self.num_classes) + return multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes) + + def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Aggregates results by stacking them instead of concatenating before averaging. + + Returns: + the aggregated results + """ + return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)]) -class MulticlassPrecisionRecallCurve(TensorCollectionMetric): + +class MulticlassPrecisionRecallCurve(TensorMetric): """Computes the multiclass PR Curve Example: @@ -674,29 +650,28 @@ class MulticlassPrecisionRecallCurve(TensorCollectionMetric): """ def __init__( - self, - num_classes: Optional[int] = None, - reduce_group: Any = None, - reduce_op: Any = None, + self, + num_classes: Optional[int] = None, + reduce_group: Any = None, ): """ Args: num_classes: number of classes reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='multiclass_precision_recall_curve', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="multiclass_precision_recall_curve", + reduce_group=reduce_group, + ) self.num_classes = num_classes def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, + self, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Actual metric computation @@ -710,10 +685,18 @@ def forward( tuple: A tuple consisting of one tuple per class, holding precision, recall and thresholds """ - return multiclass_precision_recall_curve(pred=pred, - target=target, - sample_weight=sample_weight, - num_classes=self.num_classes) + return multiclass_precision_recall_curve( + pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes + ) + + def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Aggregates results by stacking them instead of concatenating before averaging. 
+ + Returns: + the aggregated results + """ + + return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)]) class DiceCoefficient(TensorMetric): @@ -733,12 +716,12 @@ class DiceCoefficient(TensorMetric): """ def __init__( - self, - include_background: bool = False, - nan_score: float = 0.0, no_fg_score: float = 0.0, - reduction: str = 'elementwise_mean', - reduce_group: Any = None, - reduce_op: Any = None, + self, + include_background: bool = False, + nan_score: float = 0.0, + no_fg_score: float = 0.0, + reduction: str = "elementwise_mean", + reduce_group: Any = None, ): """ Args: @@ -751,11 +734,11 @@ def __init__( - none: pass array - sum: add elements reduce_group: the process group to reduce metric results from DDP - reduce_op: the operation to perform for ddp reduction """ - super().__init__(name='dice', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name="dice", + reduce_group=reduce_group, + ) self.include_background = include_background self.nan_score = nan_score @@ -773,12 +756,14 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: torch.Tensor: the calculated dice coefficient """ - return dice_score(pred=pred, - target=target, - bg=self.include_background, - nan_score=self.nan_score, - no_fg_score=self.no_fg_score, - reduction=self.reduction) + return dice_score( + pred=pred, + target=target, + bg=self.include_background, + nan_score=self.nan_score, + no_fg_score=self.no_fg_score, + reduction=self.reduction, + ) class IoU(TensorMetric): @@ -799,11 +784,7 @@ class IoU(TensorMetric): """ - def __init__( - self, - remove_bg: bool = False, - reduction: str = 'elementwise_mean' - ): + def __init__(self, remove_bg: bool = False, reduction: str = "elementwise_mean"): """ Args: remove_bg: Flag to state whether a background class has been included @@ -817,12 +798,11 @@ def __init__( - none: pass array - sum: add elements """ - super().__init__(name='iou') + super().__init__(name="iou") self.remove_bg = remove_bg self.reduction = reduction - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, - sample_weight: Optional[torch.Tensor] = None): + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None): """ Actual metric calculation. """ diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 2bfe8fa28a61c..b43462cddc5a6 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -18,6 +18,7 @@ sync tensors between different processes in a DDP scenario, when needed. 
""" +from functools import reduce import numbers from typing import Any, Callable, Optional, Union @@ -31,10 +32,11 @@ try: from torch.distributed import ReduceOp except ImportError: + class ReduceOp: SUM = None - rank_zero_warn('Unsupported `ReduceOp` for distributed computing') + rank_zero_warn("Unsupported `ReduceOp` for distributed computing") def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: @@ -138,8 +140,9 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable: Return: Callable: the decorated function """ - return _apply_to_inputs( - apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate) + return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)( + func_to_decorate + ) def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable: @@ -185,8 +188,9 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable: Return: Callable: the decorated function """ - return _apply_to_inputs( - apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate) + return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)( + func_to_decorate + ) def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable: @@ -199,8 +203,9 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C Return: Callable: the decorated function """ - return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor)(func_to_decorate) + return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)( + func_to_decorate + ) def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: @@ -240,10 +245,9 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable return _tensor_collection_metric_output_conversion(func_convert_inputs) -def sync_ddp_if_available(result: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[ReduceOp] = None - ) -> torch.Tensor: +def sync_ddp_if_available( + result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None +) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -265,14 +269,13 @@ def sync_ddp_if_available(result: Union[torch.Tensor], if reduce_op is None: reduce_op = torch.distributed.ReduceOp.SUM - elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'): + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): reduce_op = torch.distributed.ReduceOp.SUM divide_by_world_size = True # sync all processes before reduction torch.distributed.barrier(group=group) - torch.distributed.all_reduce(result, op=reduce_op, group=group, - async_op=False) + torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) if divide_by_world_size: result = result / torch.distributed.get_world_size(group) @@ -280,8 +283,21 @@ def sync_ddp_if_available(result: Union[torch.Tensor], return result -def gather_all_tensors_if_available(result: Union[torch.Tensor], - group: Optional[Any] = None): +def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + """Makes sure the tensor is at least of 1d shape + + Args: + tensor: the tensor or array to check the shape for + + Returns: + the 
optionally reshaped tensor + """ + if tensor.shape == (): + tensor = tensor.reshape(1, ) + return tensor + + +def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes @@ -312,8 +328,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor], return result -def sync_ddp(group: Optional[Any] = None, - reduce_op: Optional[ReduceOp] = None) -> Callable: +def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator syncs a functions outputs across different processes for DDP. @@ -327,15 +342,14 @@ def sync_ddp(group: Optional[Any] = None, """ def decorator_fn(func_to_decorate): - return _apply_to_outputs(apply_to_collection, torch.Tensor, - sync_ddp_if_available, group=group, - reduce_op=reduce_op)(func_to_decorate) + return _apply_to_outputs( + apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op + )(func_to_decorate) return decorator_fn -def numpy_metric(group: Optional[Any] = None, - reduce_op: Optional[ReduceOp] = None) -> Callable: +def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on numpy arrays. It handles the argument conversion and DDP reduction for metrics working on numpy. @@ -357,8 +371,7 @@ def decorator_fn(func_to_decorate): return decorator_fn -def tensor_metric(group: Optional[Any] = None, - reduce_op: Optional[ReduceOp] = None) -> Callable: +def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors. It handles the argument conversion and DDP reduction for metrics working on tensors. @@ -379,8 +392,7 @@ def decorator_fn(func_to_decorate): return decorator_fn -def tensor_collection_metric(group: Optional[Any] = None, - reduce_op: Optional[ReduceOp] = None) -> Callable: +def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: """ This decorator shall be used on all function metrics working on tensors and returning collections that cannot be converted to tensors. diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 5f61a50e6cd25..45c50b084956f 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -13,7 +13,7 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Mapping, Optional, Sequence import numbers import torch @@ -21,8 +21,11 @@ import numpy as np from pytorch_lightning.metrics.converters import ( - sync_ddp_if_available, gather_all_tensors_if_available, - convert_to_tensor, convert_to_numpy) + at_least_1d, + gather_all_tensors_if_available, + convert_to_tensor, + convert_to_numpy, +) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -40,32 +43,41 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): * input_convert: pre-forward hook that takes care of input conversion * output_convert: post-forward hook that takes care of output convertion - * ddp_sync: implementation of ddp sync, default is gather all - * aggregate: implement how values should be aggregated + * ddp_reduce: implementation of ddp sync + aggregation, default is ddp_sync + aggregate * compute: post-ddp sync for additional metric computations + ``ddp_reduce`` by default calls the following methods, which can also be overwritten if necessary. + + * ddp_sync: implements how values should be synced across ddp-processes. Defaults to gather all. + * aggregate: implement how values should be aggregated (defaults to mean). + Call order - input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute + input_convert -> forward -> output_convert -> ddp_reduce (per default being ddp_sync -> aggregate) -> compute """ - def __init__(self, name: str): + def __init__(self, name: str, reduce_group: Optional[Any] = None): """ Args: name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). 
+ Defaults to all processes (world) """ super().__init__() self.name = name self._dtype = torch.get_default_dtype() - self._device = torch.device('cpu') + self._device = torch.device("cpu") + + self.reduce_group = reduce_group + + self._step_vals = [] # Register hooks self.register_forward_pre_hook(self.input_convert) self.register_forward_hook(self.output_convert) - self.register_forward_hook(self.ddp_sync) - self.register_forward_hook(self.aggregate) + self.register_forward_hook(self.ddp_reduce) self.register_forward_hook(self.compute) @staticmethod @@ -104,12 +116,30 @@ def output_convert(self, data: Any, output: Any): Returns: casted outputs """ - return output + return apply_to_collection(output, (torch.Tensor, np.ndarray), at_least_1d) - @staticmethod - def ddp_sync(self, data: Any, output: Any): + def ddp_sync(self, tensor: Any): """ Implement how the outputs from forward should be synced + (per default just gathers all of them and adds them to self._step_vals) + + Args: + tensor: tensor to sync + + Returns: + synced output + + """ + gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group) + + self._step_vals.append(gathered_tensors) + + return gathered_tensors + + @staticmethod + def ddp_reduce(self, data: Any, output: Any): + """ + Implement how the outputs from forward should be synced and reduced across nodes Args: data: input to forward method @@ -119,27 +149,36 @@ def ddp_sync(self, data: Any, output: Any): synced output """ - return output + synced = self.ddp_sync(output) + return self.aggregate(synced) - @staticmethod - def aggregate(self, data: Any, output: Any): + def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: """ Implement aggregation of values on the same device Args: - data: input to forward method - output: output from the `ddp_sync` hook + tensors: the values to be aggregated Returns: aggregated values """ - return output + try: + return torch.cat(tensors).mean(0) + except (ValueError, TypeError): + if isinstance(tensors[0], Mapping): + return {k: torch.stack([tensor[k] for tensor in tensors]).mean(0) for k in tensors[0].keys()} + elif isinstance(tensors[0], Sequence) and not isinstance(tensors[0], torch.Tensor): + return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)]) + elif isinstance(tensors[0], torch.Tensor): + return torch.stack(tensors).mean(0) + else: + raise TypeError("unknown metric value format to aggregate") @staticmethod def compute(self, data: Any, output: Any): """ - Implement additionally metric computations to be done after the ddp sync + Implement additionally metric computations to be done after the aggregation Args: data: input to forward method @@ -151,6 +190,15 @@ def compute(self, data: Any, output: Any): """ return output + @property + def aggregated(self) -> torch.Tensor: + aggr = self.aggregate(*self._step_vals) + self.reset() + return self.compute(self, None, aggr) + + def reset(self): + self._step_vals = [] + class TensorMetric(Metric): """ @@ -159,91 +207,20 @@ class TensorMetric(Metric): Already handles DDP sync and input/output conversions. """ - def __init__(self, name: str, - reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): - """ - - Args: - name: the metric's name - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
- """ - super().__init__(name) - self.reduce_group = reduce_group - self.reduce_op = reduce_op - - @staticmethod - def input_convert(self, data: Any): - return apply_to_collection(data, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, - self.dtype, self.device) - - @staticmethod - def output_convert(self, data: Any, output: Any): - return apply_to_collection(output, torch.Tensor, convert_to_tensor, - self.dtype, self.device) - - @staticmethod - def ddp_sync(self, data: Any, output: Any): - return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) - - -class TensorCollectionMetric(Metric): - """ - Base class for metric implementation operating directly on tensors. - All inputs will be casted to tensors if necessary. Outputs won't be casted. - Already handles DDP sync and input conversions. - - This class differs from :class:`TensorMetric`, as it assumes all outputs to - be collections of tensors and does not explicitly convert them. This is - necessary, since some collections (like for ROC, Precision-Recall Curve etc.) - cannot be converted to tensors at the highest level. - All numpy arrays and numbers occuring in these outputs will still be converted. - - Use this class as a baseclass, whenever you want to ensure inputs are - tensors and outputs cannot be converted to tensors automatically - - """ - - def __init__(self, name: str, - reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): - """ - - Args: - name: the metric's name - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. - """ - super().__init__(name) - self.reduce_group = reduce_group - self.reduce_op = reduce_op - @staticmethod def input_convert(self, data: Any): - return apply_to_collection(data, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, - self.dtype, self.device) + data = apply_to_collection( + data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device + ) + return super(TensorMetric, self).input_convert(self, data) @staticmethod def output_convert(self, data: Any, output: Any): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, - self.dtype, self.device) - @staticmethod - def ddp_sync(self, data: Any, output: Any): - return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) + output = apply_to_collection( + output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device + ) + return super(TensorMetric, self).output_convert(self, data, output) class NumpyMetric(Metric): @@ -254,36 +231,15 @@ class NumpyMetric(Metric): Already handles DDP sync and input/output conversions. """ - def __init__(self, name: str, - reduce_group: Optional[Any] = None, - reduce_op: Optional[Any] = None): - """ - - Args: - name: the metric's name - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
- """ - super().__init__(name) - self.reduce_group = reduce_group - self.reduce_op = reduce_op - @staticmethod def input_convert(self, data: Any): - return apply_to_collection(data, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_numpy) + data = apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) + return super(NumpyMetric, self).input_convert(self, data) @staticmethod def output_convert(self, data: Any, output: Any): - return apply_to_collection(output, - (torch.Tensor, np.ndarray, numbers.Number), - convert_to_tensor, - self.dtype, self.device) + output = apply_to_collection( + output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device + ) - @staticmethod - def ddp_sync(self, data: Any, output: Any): - return apply_to_collection(output, torch.Tensor, sync_ddp_if_available, - self.reduce_group, self.reduce_op) + return super(NumpyMetric, self).output_convert(self, data, output) diff --git a/pytorch_lightning/metrics/sklearns.py b/pytorch_lightning/metrics/sklearns.py index e40e7ec4de87a..ea43c03135cc0 100644 --- a/pytorch_lightning/metrics/sklearns.py +++ b/pytorch_lightning/metrics/sklearns.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -22,15 +22,13 @@ from pytorch_lightning.utilities import rank_zero_warn try: - from torch.distributed import ReduceOp, group + from torch.distributed import group except ImportError: - class ReduceOp: - SUM = None class group: WORLD = None - rank_zero_warn('Unsupported `ReduceOp` for distributed computing.') + rank_zero_warn("Unsupported `ReduceOp` for distributed computing.") class SklearnMetric(NumpyMetric): @@ -45,34 +43,33 @@ class SklearnMetric(NumpyMetric): """ def __init__( - self, - metric_name: str, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, - **kwargs, + self, + metric_name: str, + reduce_group: Any = group.WORLD, + **kwargs, ): """ Args: metric_name: the metric name to import and compute from scikit-learn.metrics reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
**kwargs: additonal keyword arguments (will be forwarded to metric call) """ - super().__init__(name=metric_name, - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__( + name=metric_name, + reduce_group=reduce_group, + ) self.metric_kwargs = kwargs lightning_logger.debug( - f'Metric {self.__class__.__name__} is using Sklearn as backend, meaning that' - ' every metric call will cause a GPU synchronization, which may slow down your code' + f"Metric {self.__class__.__name__} is using Sklearn as backend, meaning that" + " every metric call will cause a GPU synchronization, which may slow down your code" ) @property def metric_fn(self): import sklearn.metrics + return getattr(sklearn.metrics, self.name) def forward(self, *args, **kwargs) -> Union[np.ndarray, int, float]: @@ -103,15 +100,14 @@ class Accuracy(SklearnMetric): >>> y_true = torch.tensor([0, 1, 2, 2]) >>> metric = Accuracy() >>> metric(y_pred, y_true) - tensor([0.7500]) + tensor(0.7500) """ def __init__( - self, - normalize: bool = True, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + normalize: bool = True, + reduce_group: Any = group.WORLD, ): """ Args: @@ -119,19 +115,14 @@ def __init__( Otherwise, return the fraction of correctly classified samples. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__(metric_name='accuracy_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - normalize=normalize) + super().__init__(metric_name="accuracy_score", reduce_group=reduce_group, normalize=normalize) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Computes the accuracy @@ -160,25 +151,20 @@ class AUC(SklearnMetric): >>> y_true = torch.tensor([0, 1, 2, 2]) >>> metric = AUC() >>> metric(y_pred, y_true) - tensor([4.]) + tensor(4.) """ def __init__( - self, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + reduce_group: Any = group.WORLD, ): """ Args: reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__(metric_name='auc', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__(metric_name="auc", reduce_group=reduce_group) def forward(self, x: np.ndarray, y: np.ndarray) -> float: """ @@ -202,10 +188,9 @@ class AveragePrecision(SklearnMetric): """ def __init__( - self, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -222,19 +207,14 @@ def __init__( reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('average_precision_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - average=average) + super().__init__("average_precision_score", reduce_group=reduce_group, average=average) def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_score: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -246,12 +226,11 @@ def forward( Return: average precision score """ - return super().forward(y_score=y_score, y_true=y_true, - sample_weight=sample_weight) + return super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight) class BalancedAccuracy(SklearnMetric): - """ Compute the balanced accuracy score + """Compute the balanced accuracy score Warning: Every metric call will cause a GPU synchronization, which may slow down your code @@ -262,15 +241,14 @@ class BalancedAccuracy(SklearnMetric): >>> y_true = torch.tensor([0, 0, 1, 1]) >>> metric = BalancedAccuracy() >>> metric(y_pred, y_true) - tensor([0.7500]) + tensor(0.7500) """ def __init__( - self, - adjusted: bool = False, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + adjusted: bool = False, + reduce_group: Any = group.WORLD, ): """ Args: @@ -278,19 +256,14 @@ def __init__( corresponds to 0 and perfect performance corresponds to 1 reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('balanced_accuracy_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - adjusted=adjusted) + super().__init__("balanced_accuracy_score", reduce_group=reduce_group, adjusted=adjusted) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -302,9 +275,7 @@ def forward( balanced accuracy score """ - return super().forward(y_true=y_true, - y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class CohenKappaScore(SklearnMetric): @@ -317,16 +288,15 @@ class CohenKappaScore(SklearnMetric): >>> y_true = torch.tensor([2, 2, 2, 1]) >>> metric = CohenKappaScore() >>> metric(y_pred, y_true) - tensor([-0.3333]) + tensor(-0.3333) """ def __init__( - self, - labels: Optional[Sequence] = None, - weights: Optional[str] = None, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + weights: Optional[str] = None, + reduce_group: Any = group.WORLD, ): """ Args: @@ -339,20 +309,14 @@ def __init__( and ``quadratic`` means quadratic weighted reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('cohen_kappa_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels, - weights=weights) + super().__init__("cohen_kappa_score", reduce_group=reduce_group, labels=labels, weights=weights) def forward( - self, - y1: np.ndarray, - y2: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y1: np.ndarray, + y2: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -386,10 +350,9 @@ class ConfusionMatrix(SklearnMetric): """ def __init__( - self, - labels: Optional[Sequence] = None, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + reduce_group: Any = group.WORLD, ): """ Args: @@ -399,13 +362,8 @@ def __init__( in ``y_true`` or ``y_pred`` are used in sorted order. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('confusion_matrix', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels) + super().__init__("confusion_matrix", reduce_group=reduce_group, labels=labels) def forward(self, y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: """ @@ -419,9 +377,12 @@ def forward(self, y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: """ return super().forward(y_pred=y_pred, y_true=y_true) + def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: + return torch.stack(tensors).mean(0) + class DCG(SklearnMetric): - """ Compute discounted cumulative gain + """Compute discounted cumulative gain Warning: Every metric call will cause a GPU synchronization, which may slow down your code @@ -432,16 +393,15 @@ class DCG(SklearnMetric): >>> y_true = torch.tensor([[10, 0, 0, 1, 5]]) >>> metric = DCG() >>> metric(y_score, y_true) - tensor([9.4995]) + tensor(9.4995) """ def __init__( - self, - k: Optional[int] = None, - log_base: float = 2, - ignore_ties: bool = False, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + k: Optional[int] = None, + log_base: float = 2, + ignore_ties: bool = False, + reduce_group: Any = group.WORLD, ): """ Args: @@ -450,21 +410,14 @@ def __init__( ignore_ties: If ``True``, assume there are no ties in y_score for efficiency gains reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('dcg_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - k=k, - log_base=log_base, - ignore_ties=ignore_ties) + super().__init__("dcg_score", reduce_group=reduce_group, k=k, log_base=log_base, ignore_ties=ignore_ties) def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_score: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -477,9 +430,7 @@ def forward( DCG score """ - return super().forward(y_true=y_true, - y_score=y_score, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_score=y_score, sample_weight=sample_weight) class F1(SklearnMetric): @@ -503,7 +454,7 @@ class F1(SklearnMetric): >>> y_true = torch.tensor([0, 1, 2, 2]) >>> metric = F1() >>> metric(y_pred, y_true) - tensor([0.6667]) + tensor(0.6667) References - [1] `Wikipedia entry for the F1-score @@ -511,12 +462,11 @@ class F1(SklearnMetric): """ def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -550,21 +500,14 @@ def __init__( behavior is deprecated and will change in version 0.18. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('f1_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels, - pos_label=pos_label, - average=average) + super().__init__("f1_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -593,7 +536,7 @@ class FBeta(SklearnMetric): >>> y_true = torch.tensor([0, 1, 2, 2]) >>> metric = FBeta(beta=0.25) >>> metric(y_pred, y_true) - tensor([0.7361]) + tensor(0.7361) References: - [1] R. Baeza-Yates and B. Ribeiro-Neto (2011). @@ -603,13 +546,12 @@ class FBeta(SklearnMetric): """ def __init__( - self, - beta: float, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + beta: float, + labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -644,22 +586,16 @@ def __init__( behavior is deprecated and will change in version 0.18. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('fbeta_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - beta=beta, - labels=labels, - pos_label=pos_label, - average=average) + super().__init__( + "fbeta_score", reduce_group=reduce_group, beta=beta, labels=labels, pos_label=pos_label, average=average + ) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -685,32 +621,27 @@ class Hamming(SklearnMetric): >>> y_true = torch.tensor([1, 1, 2, 3]) >>> metric = Hamming() >>> metric(y_pred, y_true) - tensor([0.2500]) + tensor(0.2500) """ def __init__( - self, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + reduce_group: Any = group.WORLD, ): """ Args: reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('hamming_loss', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__("hamming_loss", reduce_group=reduce_group) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -735,34 +666,28 @@ class Hinge(SklearnMetric): >>> y_true = torch.tensor([1, 1, 0, 0]) >>> metric = Hinge() >>> metric(pred_decision, y_true) - tensor([1.6300]) + tensor(1.6300) """ def __init__( - self, - labels: Optional[Sequence] = None, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + reduce_group: Any = group.WORLD, ): """ Args: labels: Integer array of labels. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('hinge_loss', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels) + super().__init__("hinge_loss", reduce_group=reduce_group, labels=labels) def forward( - self, - pred_decision: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + pred_decision: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -774,9 +699,7 @@ def forward( Average hinge loss """ - return super().forward(pred_decision=pred_decision, - y_true=y_true, - sample_weight=sample_weight) + return super().forward(pred_decision=pred_decision, y_true=y_true, sample_weight=sample_weight) class Jaccard(SklearnMetric): @@ -789,17 +712,16 @@ class Jaccard(SklearnMetric): >>> y_true = torch.tensor([0, 1, 1]) >>> metric = Jaccard() >>> metric(y_pred, y_true) - tensor([0.3333]) + tensor(0.3333) """ def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -833,21 +755,16 @@ def __init__( behavior is deprecated and will change in version 0.18. 
reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('jaccard_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels, - pos_label=pos_label, - average=average) + super().__init__( + "jaccard_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average + ) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -877,17 +794,16 @@ class Precision(SklearnMetric): >>> y_true = torch.tensor([0, 1, 2, 2]) >>> metric = Precision() >>> metric(y_pred, y_true) - tensor([0.7500]) + tensor(0.7500) """ def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -921,21 +837,16 @@ def __init__( behavior is deprecated and will change in version 0.18. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('precision_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels, - pos_label=pos_label, - average=average) + super().__init__( + "precision_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average + ) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -965,17 +876,16 @@ class Recall(SklearnMetric): >>> y_true = torch.tensor([0, 1, 2, 2]) >>> metric = Recall() >>> metric(y_pred, y_true) - tensor([0.6250]) + tensor(0.6250) """ def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1009,21 +919,14 @@ def __init__( behavior is deprecated and will change in version 0.18. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('recall_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - labels=labels, - pos_label=pos_label, - average=average) + super().__init__("recall_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -1059,29 +962,23 @@ class PrecisionRecallCurve(SklearnMetric): """ def __init__( - self, - pos_label: Union[str, int] = 1, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + pos_label: Union[str, int] = 1, + reduce_group: Any = group.WORLD, ): """ Args: pos_label: The class to report if ``average='binary'``. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('precision_recall_curve', - reduce_group=reduce_group, - reduce_op=reduce_op, - pos_label=pos_label) + super().__init__("precision_recall_curve", reduce_group=reduce_group, pos_label=pos_label) def forward( - self, - probas_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + probas_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -1103,9 +1000,15 @@ def forward( """ # only return x and y here, since for now we cannot auto-convert elements of multiple length. # Will be fixed in native implementation - return np.array(super().forward(probas_pred=probas_pred, - y_true=y_true, - sample_weight=sample_weight)[:2]) + return np.array(super().forward(probas_pred=probas_pred, y_true=y_true, sample_weight=sample_weight)[:2]) + + def aggregate(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Aggregates results by stacking them instead of concatenating before averaging. + + Returns: + the aggregated results + """ + return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)]) class ROC(SklearnMetric): @@ -1136,29 +1039,23 @@ class ROC(SklearnMetric): """ def __init__( - self, - pos_label: Union[str, int] = 1, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + pos_label: Union[str, int] = 1, + reduce_group: Any = group.WORLD, ): """ Args: pos_labels: The class to report if ``average='binary'``. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('roc_curve', - reduce_group=reduce_group, - reduce_op=reduce_op, - pos_label=pos_label) + super().__init__("roc_curve", reduce_group=reduce_group, pos_label=pos_label) def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_score: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -1182,6 +1079,15 @@ class or confidence values. 
""" return np.array(super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight)[:2]) + def aggregate(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Aggregates results by stacking them instead of concatenating before averaging. + + Returns: + the aggregated results + """ + + return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)]) + class AUROC(SklearnMetric): """ @@ -1197,10 +1103,9 @@ class AUROC(SklearnMetric): """ def __init__( - self, - average: Optional[str] = 'macro', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + average: Optional[str] = "macro", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1217,19 +1122,14 @@ def __init__( reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('roc_auc_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - average=average) + super().__init__("roc_auc_score", reduce_group=reduce_group, average=average) def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_score: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -1241,8 +1141,7 @@ def forward( Return: Area Under Receiver Operating Characteristic Curve """ - return super().forward(y_score=y_score, y_true=y_true, - sample_weight=sample_weight) + return super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight) class ExplainedVariance(SklearnMetric): @@ -1258,14 +1157,13 @@ class ExplainedVariance(SklearnMetric): >>> y_true = torch.tensor([3, -0.5, 2, 7]) >>> metric = ExplainedVariance() >>> metric(y_pred, y_true) - tensor([0.9572]) + tensor(0.9572) """ def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = 'variance_weighted', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + multioutput: Optional[Union[str, List[float]]] = "variance_weighted", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1274,19 +1172,14 @@ def __init__( output values should be aggregated. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('explained_variance_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - multioutput=multioutput) + super().__init__("explained_variance_score", reduce_group=reduce_group, multioutput=multioutput) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1298,8 +1191,7 @@ def forward( Explained variance score """ - return super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class MeanAbsoluteError(SklearnMetric): @@ -1315,15 +1207,14 @@ class MeanAbsoluteError(SklearnMetric): >>> y_true = torch.tensor([3, -0.5, 2, 7]) >>> metric = MeanAbsoluteError() >>> metric(y_pred, y_true) - tensor([0.5000]) + tensor(0.5000) """ def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = 'uniform_average', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + multioutput: Optional[Union[str, List[float]]] = "uniform_average", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1332,16 +1223,10 @@ def __init__( output values should be aggregated. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('mean_absolute_error', - reduce_group=reduce_group, - reduce_op=reduce_op, - multioutput=multioutput) + super().__init__("mean_absolute_error", reduce_group=reduce_group, multioutput=multioutput) - def forward(self, y_pred: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None): + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, sample_weight: Optional[np.ndarray] = None): """ Args: y_pred: Estimated target values @@ -1352,9 +1237,7 @@ def forward(self, y_pred: np.ndarray, y_true: np.ndarray, Mean absolute error """ - return super().forward(y_true=y_true, - y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class MeanSquaredError(SklearnMetric): @@ -1370,19 +1253,18 @@ class MeanSquaredError(SklearnMetric): >>> y_true = torch.tensor([3, -0.5, 2, 7]) >>> metric = MeanSquaredError() >>> metric(y_pred, y_true) - tensor([0.3750]) + tensor(0.3750) >>> metric = MeanSquaredError(squared=True) >>> metric(y_pred, y_true) - tensor([0.6124]) + tensor(0.6124) """ def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = 'uniform_average', - squared: bool = False, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + multioutput: Optional[Union[str, List[float]]] = "uniform_average", + squared: bool = False, + reduce_group: Any = group.WORLD, ): """ Args: @@ -1392,20 +1274,15 @@ def __init__( squared: if ``True`` returns the mse value else the rmse value reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('mean_squared_error', - reduce_group=reduce_group, - reduce_op=reduce_op, - multioutput=multioutput) + super().__init__("mean_squared_error", reduce_group=reduce_group, multioutput=multioutput) self.squared = squared def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1417,8 +1294,7 @@ def forward( Mean squared error """ - mse = super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + mse = super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) if self.squared: mse = np.sqrt(mse) return mse @@ -1437,14 +1313,13 @@ class MeanSquaredLogError(SklearnMetric): >>> y_true = torch.tensor([3, 5, 2.5, 7]) >>> metric = MeanSquaredLogError() >>> metric(y_pred, y_true) - tensor([0.0397]) + tensor(0.0397) """ def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = 'uniform_average', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + multioutput: Optional[Union[str, List[float]]] = "uniform_average", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1453,19 +1328,14 @@ def __init__( output values should be aggregated. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('mean_squared_log_error', - reduce_group=reduce_group, - reduce_op=reduce_op, - multioutput=multioutput) + super().__init__("mean_squared_log_error", reduce_group=reduce_group, multioutput=multioutput) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1477,8 +1347,7 @@ def forward( Mean squared log error """ - return super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class MedianAbsoluteError(SklearnMetric): @@ -1494,14 +1363,13 @@ class MedianAbsoluteError(SklearnMetric): >>> y_true = torch.tensor([3, -0.5, 2, 7]) >>> metric = MedianAbsoluteError() >>> metric(y_pred, y_true) - tensor([0.5000]) + tensor(0.5000) """ def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = 'uniform_average', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + multioutput: Optional[Union[str, List[float]]] = "uniform_average", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1510,13 +1378,8 @@ def __init__( output values should be aggregated. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. 
""" - super().__init__('median_absolute_error', - reduce_group=reduce_group, - reduce_op=reduce_op, - multioutput=multioutput) + super().__init__("median_absolute_error", reduce_group=reduce_group, multioutput=multioutput) def forward(self, y_pred: np.ndarray, y_true: np.ndarray): """ @@ -1544,14 +1407,13 @@ class R2Score(SklearnMetric): >>> y_true = torch.tensor([3, -0.5, 2, 7]) >>> metric = R2Score() >>> metric(y_pred, y_true) - tensor([0.9486]) + tensor(0.9486) """ def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = 'uniform_average', - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + multioutput: Optional[Union[str, List[float]]] = "uniform_average", + reduce_group: Any = group.WORLD, ): """ Args: @@ -1560,19 +1422,14 @@ def __init__( output values should be aggregated. reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('r2_score', - reduce_group=reduce_group, - reduce_op=reduce_op, - multioutput=multioutput) + super().__init__("r2_score", reduce_group=reduce_group, multioutput=multioutput) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1584,8 +1441,7 @@ def forward( R^2 score """ - return super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class MeanPoissonDeviance(SklearnMetric): @@ -1601,30 +1457,25 @@ class MeanPoissonDeviance(SklearnMetric): >>> y_true = torch.tensor([0.5, 0.5, 2., 2.]) >>> metric = MeanPoissonDeviance() >>> metric(y_pred, y_true) - tensor([0.9034]) + tensor(0.9034) """ def __init__( - self, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + reduce_group: Any = group.WORLD, ): """ Args: reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('mean_poisson_deviance', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__("mean_poisson_deviance", reduce_group=reduce_group) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1636,8 +1487,7 @@ def forward( Mean possion deviance """ - return super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class MeanGammaDeviance(SklearnMetric): @@ -1653,30 +1503,25 @@ class MeanGammaDeviance(SklearnMetric): >>> y_true = torch.tensor([2, 0.5, 1, 4]) >>> metric = MeanGammaDeviance() >>> metric(y_pred, y_true) - tensor([1.0569]) + tensor(1.0569) """ def __init__( - self, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + reduce_group: Any = group.WORLD, ): """ Args: reduce_group: the process group for DDP reduces (only needed for DDP training). 
Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('mean_gamma_deviance', - reduce_group=reduce_group, - reduce_op=reduce_op) + super().__init__("mean_gamma_deviance", reduce_group=reduce_group) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1688,8 +1533,7 @@ def forward( Mean gamma deviance """ - return super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) class MeanTweedieDeviance(SklearnMetric): @@ -1705,14 +1549,13 @@ class MeanTweedieDeviance(SklearnMetric): >>> y_true = torch.tensor([0.5, 0.5, 2., 2.]) >>> metric = MeanTweedieDeviance() >>> metric(y_pred, y_true) - tensor([1.8125]) + tensor(1.8125) """ def __init__( - self, - power: float = 0, - reduce_group: Any = group.WORLD, - reduce_op: Any = ReduceOp.SUM, + self, + power: float = 0, + reduce_group: Any = group.WORLD, ): """ Args: @@ -1729,19 +1572,14 @@ def __init__( reduce_group: the process group for DDP reduces (only needed for DDP training). Defaults to all processes (world) - reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). - Defaults to sum. """ - super().__init__('mean_tweedie_deviance', - reduce_group=reduce_group, - reduce_op=reduce_op, - power=power) + super().__init__("mean_tweedie_deviance", reduce_group=reduce_group, power=power) def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, ): """ Args: @@ -1753,5 +1591,4 @@ def forward( Mean tweedie deviance """ - return super().forward(y_true=y_true, y_pred=y_pred, - sample_weight=sample_weight) + return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 5985745bfa070..5a8589f8f254a 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,48 +1,49 @@ import os +from typing import Any import numpy as np import pytest import torch import tests.base.develop_utils as tutils from tests.base import EvalModelTemplate -from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric from pytorch_lightning import Trainer class DummyTensorMetric(TensorMetric): def __init__(self): - super().__init__('dummy') + super().__init__("dummy") def forward(self, input1, input2): assert isinstance(input1, torch.Tensor) assert isinstance(input2, torch.Tensor) - return torch.tensor([1.]) + return torch.tensor([1.0]) class DummyNumpyMetric(NumpyMetric): def __init__(self): - super().__init__('dummy') + super().__init__("dummy") def forward(self, input1, input2): assert isinstance(input1, np.ndarray) assert isinstance(input2, np.ndarray) - return 1. 
+ return 1.0 -class DummyTensorCollectionMetric(TensorCollectionMetric): +class DummyTensorCollectionMetric(TensorMetric): def __init__(self): - super().__init__('dummy') + super().__init__("dummy") def forward(self, input1, input2): assert isinstance(input1, torch.Tensor) assert isinstance(input2, torch.Tensor) - return 1., 2., 3., 4. + return 1.0, 2.0, 3.0, 4.0 -@pytest.mark.parametrize('metric', [DummyTensorCollectionMetric()]) +@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()]) def test_collection_metric(metric: Metric): """ Test that metric.device, metric.dtype works for metric collection """ - input1, input2 = torch.tensor([1.]), torch.tensor([2.]) + input1, input2 = torch.tensor([1.0]), torch.tensor([2.0]) def change_and_check_device_dtype(device, dtype): metric.to(device=device, dtype=dtype) @@ -56,9 +57,9 @@ def change_and_check_device_dtype(device, dtype): if dtype is not None: assert metric.dtype == dtype - devices = [None, 'cpu'] + devices = [None, "cpu"] if torch.cuda.is_available(): - devices += ['cuda:0'] + devices += ["cuda:0"] for device in devices: for dtype in [None, torch.float32, torch.float64]: @@ -66,10 +67,10 @@ def change_and_check_device_dtype(device, dtype): if torch.cuda.is_available(): metric.cuda(0) - assert metric.device == torch.device('cuda', index=0) + assert metric.device == torch.device("cuda", index=0) metric.cpu() - assert metric.device == torch.device('cpu') + assert metric.device == torch.device("cpu") metric.type(torch.int8) assert metric.dtype == torch.int8 @@ -87,13 +88,16 @@ def change_and_check_device_dtype(device, dtype): assert metric.dtype == torch.float16 -@pytest.mark.parametrize('metric', [ - DummyTensorMetric(), - DummyNumpyMetric(), -]) +@pytest.mark.parametrize( + "metric", + [ + DummyTensorMetric(), + DummyNumpyMetric(), + ], +) def test_metric(metric: Metric): """ Test that metric.device, metric.dtype works for single metric""" - input1, input2 = torch.tensor([1.]), torch.tensor([2.]) + input1, input2 = torch.tensor([1.0]), torch.tensor([2.0]) def change_and_check_device_dtype(device, dtype): metric.to(device=device, dtype=dtype) @@ -109,9 +113,9 @@ def change_and_check_device_dtype(device, dtype): assert metric.dtype == dtype assert metric_val.dtype == dtype - devices = [None, 'cpu'] + devices = [None, "cpu"] if torch.cuda.is_available(): - devices += ['cuda:0'] + devices += ["cuda:0"] for device in devices: for dtype in [None, torch.float32, torch.float64]: @@ -119,16 +123,12 @@ def change_and_check_device_dtype(device, dtype): if torch.cuda.is_available(): metric.cuda(0) - assert metric.device == torch.device('cuda', index=0) - assert metric(input1, input2).device == torch.device('cuda', index=0) + assert metric.device == torch.device("cuda", index=0) + assert metric(input1, input2).device == torch.device("cuda", index=0) metric.cpu() - assert metric.device == torch.device('cpu') - assert metric(input1, input2).device == torch.device('cpu') - - metric.type(torch.int8) - assert metric.dtype == torch.int8 - assert metric(input1, input2).dtype == torch.int8 + assert metric.device == torch.device("cpu") + assert metric(input1, input2).device == torch.device("cpu") metric.float() assert metric.dtype == torch.float32 @@ -156,7 +156,7 @@ def test_model_pickable(tmpdir, metric: Metric): max_epochs=1, limit_train_batches=10, gpus=[0, 1], - distributed_backend='ddp_spawn', + distributed_backend="ddp_spawn", ) model = EvalModelTemplate() @@ -167,17 +167,19 @@ def test_model_pickable(tmpdir, metric: Metric): result = 
trainer.fit(model) # correct result and ok accuracy - assert result == 1, 'ddp model failed to complete' + assert result == 1, "ddp model failed to complete" @pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()]) def test_saving_pickable(tmpdir, metric: Metric): """ Make sure that metrics are pickable by saving and loading them using torch """ - x, y = torch.randn(10,), torch.randn(10,) + x, y = torch.randn(10,), torch.randn( + 10, + ) results_before_save = metric(x, y) # save metric - save_path = os.path.join(tmpdir, 'save_test.ckpt') + save_path = os.path.join(tmpdir, "save_test.ckpt") torch.save(metric, save_path) # load metric @@ -186,3 +188,125 @@ def test_saving_pickable(tmpdir, metric: Metric): # Check metric value is the same assert results_before_save == results_after_load + + +def check_call_order(): + class DummyMetric(Metric): + def __init__(self): + super().__init__("dummy") + self.call_history = ["init"] + + @staticmethod + def input_convert(self, data: Any): + self.call_history.append("input_convert") + return super(DummyMetric, self).input_convert(self, data) + + def forward(self, tensor1, tensor2): + self.call_history.append("forward") + return tensor1 - tensor2 + + @staticmethod + def output_convert(self, data: Any, output: Any): + self.call_history.append("output_convert") + return super(DummyMetric, self).output_convert(self, data, output) + + def ddp_sync(self, tensor: Any): + self.call_history.append("ddp_sync") + return super().ddp_sync(tensor) + + @staticmethod + def ddp_reduce(self, data: Any, output: Any): + self.call_history.append("ddp_reduce") + return super(DummyMetric, self).ddp_reduce(self, data, output) + + def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: + self.call_history.append("aggregate") + return super().aggregate(*tensors) + + def reset(self): + self.call_history.append("reset") + return super().reset() + + @property + def aggregated(self) -> torch.Tensor: + self.call_history.append("aggregated") + return super().aggregated + + @staticmethod + def compute(self, data: Any, output: Any): + self.call_history.append("compute") + return super(DummyMetric, self).compute(self, data, output) + + metric = DummyMetric() + assert metric.call_history == ["init"] + result = metric(torch.tensor([2.0]), torch.tensor([1.0])) + assert torch.allclose(result, torch.tensor(1.0)) + assert metric.call_history == [ + "init", + "input_convert", + "forward", + "output_convert", + "ddp_reduce", + "ddp_sync", + "aggregate", + ] + aggr = metric.aggregated + assert metric.call_history == [ + "init", + "input_convert", + "forward", + "output_convert", + "ddp_reduce", + "ddp_sync", + "aggregate", + "aggregated", + "aggregate", + "reset", + ] + assert torch.allclose(aggr, result) + _ = metric(torch.tensor(2.0), torch.tensor(1.0)) + assert metric.call_history == [ + "init", + "input_convert", + "forward", + "output_convert", + "ddp_reduce", + "ddp_sync", + "aggregate", + "aggregated", + "aggregate", + "reset", + "input_convert", + "forward", + "output_convert", + "ddp_reduce", + "ddp_sync", + "aggregate", + ] + + metric = DummyMetric() + _ = metric(torch.tensor([2.0]), torch.tensor([1.0])) + _ = metric(torch.tensor([3.0]), torch.tensor([0.0])) + + aggregated = metric.aggregated + + assert torch.allclose(aggregated, torch.tensor(2.0)) + + assert metric.call_history == [ + "init", + "input_convert", + "forward", + "output_convert", + "ddp_reduce", + "ddp_sync", + "aggregate", + "input_convert", + "forward", + "output_convert", + "ddp_reduce", + 
"ddp_sync", + "aggregate", + "aggregated", + "aggregate", + "reset", + ] diff --git a/tests/metrics/test_sklearn.py b/tests/metrics/test_sklearn.py index 10b57417411c4..019048056016c 100644 --- a/tests/metrics/test_sklearn.py +++ b/tests/metrics/test_sklearn.py @@ -167,13 +167,12 @@ def test_sklearn_metric(metric_class, sklearn_func, inputs): sklearn_result = sklearn_func(**numpy_inputs) lightning_result = metric_class(**inputs) - assert np.allclose(sklearn_result, lightning_result, atol=1e-5) sklearn_result = apply_to_collection( sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - lightning_result = apply_to_collection( - lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) + lightning_result = np.array(apply_to_collection( + lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)) assert np.allclose(sklearn_result, lightning_result, atol=1e-5) assert isinstance(lightning_result, type(sklearn_result))