diff --git a/cyclops/evaluate/metrics/functional/__init__.py b/cyclops/evaluate/metrics/functional/__init__.py index 14eee5b9e..ed64063d1 100644 --- a/cyclops/evaluate/metrics/functional/__init__.py +++ b/cyclops/evaluate/metrics/functional/__init__.py @@ -37,12 +37,14 @@ recall, ) from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # noqa: F401 + PRCurve, binary_precision_recall_curve, multiclass_precision_recall_curve, multilabel_precision_recall_curve, precision_recall_curve, ) from cyclops.evaluate.metrics.functional.roc import ( # noqa: F401 + ROCCurve, binary_roc_curve, multiclass_roc_curve, multilabel_roc_curve, diff --git a/cyclops/evaluate/metrics/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/functional/precision_recall_curve.py index bbe0bcc78..a0f9b69e3 100644 --- a/cyclops/evaluate/metrics/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/functional/precision_recall_curve.py @@ -1,6 +1,6 @@ """Functions for computing the precision-recall curve for different input types.""" -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -15,6 +15,14 @@ ) +class PRCurve(NamedTuple): + """Named tuple with Precision-Recall curve (Precision, Recall and thresholds).""" + + precision: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + recall: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + + def _format_thresholds( thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, ) -> Optional[npt.NDArray[np.float_]]: @@ -279,7 +287,7 @@ def binary_precision_recall_curve( preds: npt.ArrayLike, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, pos_label: int = 1, -) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: 
+) -> PRCurve: """Compute precision-recall curve for binary input. Parameters @@ -301,13 +309,10 @@ def binary_precision_recall_curve( Returns ------- - precision : numpy.ndarray - Precision scores such that element i is the precision of predictions - with score >= thresholds[i]. - recall : numpy.ndarray - Recall scores in descending order. - thresholds : numpy.ndarray - Thresholds used for computing the precision and recall scores. + PRCurve + A named tuple containing the precision (element i is the precision of predictions + with score >= thresholds[i]), recall (scores in descending order) + and thresholds used to compute the precision-recall curve. Examples -------- @@ -335,13 +340,14 @@ def binary_precision_recall_curve( thresholds = _format_thresholds(thresholds) state = _binary_precision_recall_curve_update(target, preds, thresholds) - - return _binary_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _binary_precision_recall_curve_compute( state, thresholds, pos_label=pos_label, ) + return PRCurve(precision_, recall_, thresholds_) + def _multiclass_precision_recall_curve_format( target: npt.ArrayLike, @@ -572,14 +578,7 @@ def multiclass_precision_recall_curve( preds: npt.ArrayLike, num_classes: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for multiclass problems. Parameters @@ -600,18 +599,13 @@ def multiclass_precision_recall_curve( Returns ------- - precision : numpy.ndarray or list of numpy.ndarray - Precision scores where element i is the precision score corresponding - to the threshold i. 
If state is a tuple of the target and predicted - probabilities, then precision is a list of arrays, where each array - corresponds to the precision scores for a class. - recall : numpy.ndarray or list of numpy.ndarray - Recall scores where element i is the recall score corresponding to - the threshold i. If state is a tuple of the target and predicted - probabilities, then recall is a list of arrays, where each array - corresponds to the recall scores for a class. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used for computing the precision and recall scores. + PRCurve + A named tuple containing the precision, recall, and thresholds. + Precision and recall are arrays where element i is the precision and + recall score corresponding to threshold i. If state is a tuple of the + target and predicted probabilities, then precision and recall are lists + of arrays, where each array corresponds to the precision and recall + scores for a class. Examples -------- @@ -652,11 +646,12 @@ def multiclass_precision_recall_curve( thresholds=thresholds, ) - return _multiclass_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _multiclass_precision_recall_curve_compute( state, thresholds, # type: ignore num_classes, ) + return PRCurve(precision_, recall_, thresholds_) def _multilabel_precision_recall_curve_format( @@ -868,14 +863,7 @@ def multilabel_precision_recall_curve( preds: npt.ArrayLike, num_labels: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for multilabel input. 
Parameters @@ -897,16 +885,18 @@ def multilabel_precision_recall_curve( Returns ------- - precision : numpy.ndarray or List[numpy.ndarray] + PRCurve + A named tuple with the following: + - ``precision``: numpy.ndarray or List[numpy.ndarray]. Precision values for each label. If ``thresholds`` is None, then precision is a list of arrays, one for each label. Otherwise, precision is a single array with shape (``num_labels``, len(``thresholds``)). - recall : numpy.ndarray or List[numpy.ndarray] + - ``recall``: numpy.ndarray or List[numpy.ndarray]. Recall values for each label. If ``thresholds`` is None, then recall is a list of arrays, one for each label. Otherwise, recall is a single array with shape (``num_labels``, len(``thresholds``)). - thresholds : numpy.ndarray or List[numpy.ndarray] + - ``thresholds``: numpy.ndarray or List[numpy.ndarray]. If ``thresholds`` is None, then thresholds is a list of arrays, one for each label. Otherwise, thresholds is a single array with shape (len(``thresholds``,). @@ -950,11 +940,12 @@ def multilabel_precision_recall_curve( thresholds=thresholds, ) - return _multilabel_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _multilabel_precision_recall_curve_compute( state, thresholds, # type: ignore num_labels, ) + return PRCurve(precision_, recall_, thresholds_) def precision_recall_curve( @@ -965,14 +956,7 @@ def precision_recall_curve( pos_label: int = 1, num_classes: Optional[int] = None, num_labels: Optional[int] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for different tasks/input types. Parameters @@ -997,17 +981,19 @@ def precision_recall_curve( Returns ------- - precision : numpy.ndarray + PRCurve + A named tuple with the following: + - ``precision``: numpy.ndarray or List[numpy.ndarray]. 
The precision scores where ``precision[i]`` is the precision score for - ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', + ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel', then ``precision`` is a list of numpy arrays, where ``precision[i]`` is the precision scores for class or label ``i``. - recall : numpy.ndarray + - ``recall``: numpy.ndarray or List[numpy.ndarray]. The recall scores where ``recall[i]`` is the recall score for ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', then ``recall`` is a list of numpy arrays, where ``recall[i]`` is the recall scores for class or label ``i``. - thresholds : numpy.ndarray + - ``thresholds``: numpy.ndarray or List[numpy.ndarray]. Thresholds used for computing the precision and recall scores. Raises diff --git a/cyclops/evaluate/metrics/functional/roc.py b/cyclops/evaluate/metrics/functional/roc.py index 24eb1dd16..b8935ded5 100644 --- a/cyclops/evaluate/metrics/functional/roc.py +++ b/cyclops/evaluate/metrics/functional/roc.py @@ -1,6 +1,6 @@ """Functions for computing the receiver operating characteristic (ROC) curve.""" import logging -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -23,6 +23,14 @@ setup_logging(print_level="WARN", logger=LOGGER) +class ROCCurve(NamedTuple): + """Named tuple to store ROC curve (FPR, TPR and thresholds).""" + + fpr: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + tpr: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + + def _roc_compute_from_confmat( confmat: npt.NDArray[Any], thresholds: npt.NDArray[np.float_], @@ -144,7 +152,7 @@ def binary_roc_curve( preds: npt.ArrayLike, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, pos_label: int = 1, -) -> 
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: +) -> ROCCurve: """Compute the ROC curve for binary classification tasks. Parameters @@ -166,12 +174,9 @@ def binary_roc_curve( Returns ------- - fpr : numpy.ndarray - False positive rate. - tpr : numpy.ndarray - True positive rate. - thresholds : numpy.ndarray - Thresholds used to compute fpr and tpr. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. Examples -------- @@ -197,8 +202,9 @@ def binary_roc_curve( thresholds = _format_thresholds(thresholds) state = _binary_precision_recall_curve_update(target, preds, thresholds) + fpr, tpr, thresholds = _binary_roc_compute(state, thresholds, pos_label) - return _binary_roc_compute(state, thresholds=thresholds, pos_label=pos_label) + return ROCCurve(fpr, tpr, thresholds) def _multiclass_roc_compute( @@ -272,14 +278,7 @@ def multiclass_roc_curve( preds: npt.ArrayLike, num_classes: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for multiclass classification tasks. Parameters @@ -301,17 +300,11 @@ def multiclass_roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``threshold`` is not None, ``fpr`` is a 1d numpy - array. Otherwise, ``fpr`` is a list of 1d numpy arrays, one for each - class. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``threshold`` is not None, ``tpr`` is a 1d numpy - array. Otherwise, ``tpr`` is a list of 1d numpy arrays, one for each class. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. ``threshold`` is not None, - thresholds is a 1d numpy array. 
Otherwise, thresholds is a list of - 1d numpy arrays, one for each class. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each class. Examples -------- @@ -352,8 +345,9 @@ def multiclass_roc_curve( num_classes=num_classes, thresholds=thresholds, ) + fpr_, tpr_, thresholds_ = _multiclass_roc_compute(state, num_classes, thresholds) - return _multiclass_roc_compute(state, num_classes, thresholds) + return ROCCurve(fpr=fpr_, tpr=tpr_, thresholds=thresholds_) def _multilabel_roc_compute( @@ -427,14 +421,7 @@ def multilabel_roc_curve( preds: npt.ArrayLike, num_labels: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for multilabel classification tasks. Parameters @@ -456,17 +443,11 @@ def multilabel_roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``threshold`` is not None, ``fpr`` is a 1d numpy - array. Otherwise, ``fpr`` is a list of 1d numpy arrays, one for each - label. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``threshold`` is not None, ``tpr`` is a 1d numpy - array. Otherwise, ``tpr`` is a list of 1d numpy arrays, one for each label. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. ``threshold`` is not None, - thresholds is a 1d numpy array. Otherwise, thresholds is a list of - 1d numpy arrays, one for each label. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. 
If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each label. Examples -------- @@ -502,8 +483,9 @@ def multilabel_roc_curve( num_labels=num_labels, thresholds=thresholds, ) + fpr_, tpr_, thresholds_ = _multilabel_roc_compute(state, num_labels, thresholds) - return _multilabel_roc_compute(state, num_labels, thresholds) + return ROCCurve(fpr=fpr_, tpr=tpr_, thresholds=thresholds_) def roc_curve( @@ -514,14 +496,7 @@ def roc_curve( pos_label: int = 1, num_classes: Optional[int] = None, num_labels: Optional[int] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for different tasks/input types. Parameters @@ -558,22 +533,11 @@ def roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``task`` is 'binary' or ``threshold`` is not None, - ``fpr`` is a 1d numpy array. If ``task`` is 'multiclass' or 'multilabel', - and ``threshold`` is None, then ``fpr`` is a list of 1d numpy - arrays, one for each class or label. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``task`` is 'binary' or ``threshold`` is not None, - ``tpr`` is a 1d numpy array. If ``task`` is 'multiclass' or 'multilabel', - and ``threshold`` is None, then ``tpr`` is a list of 1d numpy - arrays, one for each class or label. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. If ``task`` is 'binary' or - ``threshold`` is not None, ``thresholds`` is a 1d numpy array. If - ``task`` is 'multiclass' or 'multilabel', and ``threshold`` is None, - then ``thresholds`` is a list of 1d numpy arrays, one for each class - or label. 
+ ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each class or label. Raises ------ diff --git a/cyclops/evaluate/metrics/precision_recall_curve.py b/cyclops/evaluate/metrics/precision_recall_curve.py index 64bf08833..9a5ce76b5 100644 --- a/cyclops/evaluate/metrics/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/precision_recall_curve.py @@ -1,11 +1,12 @@ """Classes for computing precision-recall curves.""" -from typing import Any, List, Literal, Optional, Tuple, Type, Union +from typing import Any, List, Literal, Optional, Type, Union import numpy as np import numpy.typing as npt from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # type: ignore # noqa: E501 + PRCurve, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format, _binary_precision_recall_curve_update, @@ -42,14 +43,14 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = BinaryPrecisionRecallCurve(thresholds=3) >>> metric(target, preds) - (array([0.5, 1. , 0. ]), array([1. , 0.5, 0. ]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([0.5, 1. , 0. ]), recall=array([1. , 0.5, 0. ]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 0, 1], [1, 1, 0, 0]] >>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0.5 , 0.66666667, 0. ]), array([1. , 0.5, 0. ]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([0.5 , 0.66666667, 0. ]), recall=array([1. , 0.5, 0. ]), thresholds=array([0. , 0.5, 1. 
])) """ @@ -101,7 +102,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -111,11 +112,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _binary_precision_recall_curve_compute( + precision, recall, thresholds = _binary_precision_recall_curve_compute( state=state, thresholds=self.thresholds, pos_label=self.pos_label, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -181,11 +183,11 @@ class MulticlassPrecisionRecallCurve( >>> preds = [[0.1, 0.6, 0.3], [0.05, 0.95, 0.0], [0.5, 0.3, 0.2], [0.2, 0.5, 0.3]] >>> metric = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=3) >>> metric(target, preds) - (array([[0.5 , 0. , 0. , 1. ], + PRCurve(precision=array([[0.5 , 0. , 0. , 1. ], [0.25 , 0.33333333, 0. , 1. ], - [0.25 , 0. , 0. , 1. ]]), array([[1., 0., 0., 0.], + [0.25 , 0. , 0. , 1. ]]), recall=array([[1., 0., 0., 0.], [1., 1., 0., 0.], - [1., 0., 0., 0.]]), array([0. , 0.5, 1. ])) + [1., 0., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 2, 0], [1, 2, 0, 1]] >>> preds = [ @@ -195,11 +197,11 @@ class MulticlassPrecisionRecallCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.375, 0.5 , 0. , 1. ], + PRCurve(precision=array([[0.375, 0.5 , 0. , 1. ], [0.375, 0.4 , 0. , 1. ], - [0.25 , 0. , 0. , 1. ]]), array([[1. , 0.33333333, 0. , 0. ], + [0.25 , 0. , 0. , 1. ]]), recall=array([[1. , 0.33333333, 0. , 0. ], [1. , 0.66666667, 0. , 0. ], - [1. , 0. , 0. , 0. ]]), array([0. , 0.5, 1. ])) + [1. , 0. , 0. , 0. ]]), thresholds=array([0. , 0.5, 1. 
])) """ @@ -253,14 +255,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -270,11 +265,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multiclass_precision_recall_curve_compute( + precision, recall, thresholds = _multiclass_precision_recall_curve_compute( state=state, thresholds=self.thresholds, # type: ignore[arg-type] num_classes=self.num_classes, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -340,18 +336,18 @@ class MultilabelPrecisionRecallCurve( >>> preds = [[0.1, 0.9], [0.8, 0.2]] >>> metric = MultilabelPrecisionRecallCurve(num_labels=2, thresholds=3) >>> metric(target, preds) - (array([[0.5, 1. , 0. , 1. ], - [0.5, 1. , 0. , 1. ]]), array([[1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 1. , 0. , 1. ], + [0.5, 1. , 0. , 1. ]]), recall=array([[1., 1., 0., 0.], + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[[0, 1], [1, 0]], [[1, 0], [0, 1]]] >>> preds = [[[0.1, 0.9], [0.8, 0.2]], [[0.2, 0.8], [0.7, 0.3]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.5, 0.5, 0. , 1. ], - [0.5, 0.5, 0. , 1. ]]), array([[1. , 0.5, 0. , 0. ], - [1. , 0.5, 0. , 0. ]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 0.5, 0. , 1. ], + [0.5, 0.5, 0. , 1. ]]), recall=array([[1. , 0.5, 0. , 0. ], + [1. , 0.5, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. 
])) """ @@ -405,14 +401,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -422,11 +411,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multilabel_precision_recall_curve_compute( + precision, recall, thresholds = _multilabel_precision_recall_curve_compute( state, thresholds=self.thresholds, # type: ignore[arg-type] num_labels=self.num_labels, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -502,15 +492,15 @@ class PrecisionRecallCurve( >>> preds = [0.6, 0.2, 0.3, 0.8] >>> metric = PrecisionRecallCurve(task="binary", thresholds=None) >>> metric(target, preds) - (array([0.75 , 0.66666667, 0.5 , 0. , 1. ]), array([1. , 0.66666667, 0.33333333, 0. , 0. ]), array([0.2, 0.3, 0.6, 0.8])) + PRCurve(precision=array([0.75 , 0.66666667, 0.5 , 0. , 1. ]), recall=array([1. , 0.66666667, 0.33333333, 0. , 0. ]), thresholds=array([0.2, 0.3, 0.6, 0.8])) >>> metric.reset_state() >>> target = [[1, 0, 1, 1], [0, 0, 0, 1]] >>> preds = [[0.5, 0.4, 0.1, 0.3], [0.9, 0.6, 0.45, 0.8]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0.5 , 0.42857143, 0.33333333, 0.4 , 0.5 , - 0.33333333, 0.5 , 0. , 1. ]), array([1. , 0.75, 0.5 , 0.5 , 0.5 , 0.25, 0.25, 0. , 0. ]), array([0.1 , 0.3 , 0.4 , 0.45, 0.5 , 0.6 , 0.8 , 0.9 ])) + PRCurve(precision=array([0.5 , 0.42857143, 0.33333333, 0.4 , 0.5 , + 0.33333333, 0.5 , 0. , 1. ]), recall=array([1. , 0.75, 0.5 , 0.5 , 0.5 , 0.25, 0.25, 0. , 0. 
]), thresholds=array([0.1 , 0.3 , 0.4 , 0.45, 0.5 , 0.6 , 0.8 , 0.9 ])) >>> # (multiclass) >>> from cyclops.evaluate.metrics import PrecisionRecallCurve @@ -518,11 +508,11 @@ class PrecisionRecallCurve( >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.2, 0.2, 0.6]] >>> metric = PrecisionRecallCurve(task="multiclass", num_classes=3, thresholds=3) >>> metric(target, preds) - (array([[0.25, 0. , 0. , 1. ], + PRCurve(precision=array([[0.25, 0. , 0. , 1. ], [0.25, 0.5 , 0. , 1. ], - [0.5 , 1. , 0. , 1. ]]), array([[1., 0., 0., 0.], + [0.5 , 1. , 0. , 1. ]]), recall=array([[1., 0., 0., 0.], [1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 2, 2], [1, 2, 0, 1]] >>> preds = [ @@ -532,11 +522,11 @@ class PrecisionRecallCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.25 , 0. , 0. , 1. ], + PRCurve(precision=array([[0.25 , 0. , 0. , 1. ], [0.375, 0.5 , 0. , 1. ], - [0.375, 0.5 , 0. , 1. ]]), array([[1. , 0. , 0. , 0. ], + [0.375, 0.5 , 0. , 1. ]]), recall=array([[1. , 0. , 0. , 0. ], [1. , 0.66666667, 0. , 0. ], - [1. , 0.66666667, 0. , 0. ]]), array([0. , 0.5, 1. ])) + [1. , 0.66666667, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) >>> # (multilabel) >>> from cyclops.evaluate.metrics import PrecisionRecallCurve @@ -544,18 +534,18 @@ class PrecisionRecallCurve( >>> preds = [[0.1, 0.9], [0.8, 0.2]] >>> metric = PrecisionRecallCurve(task="multilabel", num_labels=2, thresholds=3) >>> metric(target, preds) - (array([[0.5, 1. , 0. , 1. ], - [0.5, 1. , 0. , 1. ]]), array([[1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 1. , 0. , 1. ], + [0.5, 1. , 0. , 1. ]]), recall=array([[1., 1., 0., 0.], + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. 
])) >>> metric.reset_state() >>> target = [[[0, 1], [1, 0]], [[1, 0], [0, 1]]] >>> preds = [[[0.1, 0.9], [0.8, 0.2]], [[0.1, 0.9], [0.8, 0.2]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.5, 0.5, 0. , 1. ], - [0.5, 0.5, 0. , 1. ]]), array([[1. , 0.5, 0. , 0. ], - [1. , 0.5, 0. , 0. ]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 0.5, 0. , 1. ], + [0.5, 0.5, 0. , 1. ]]), recall=array([[1. , 0.5, 0. , 0. ], + [1. , 0.5, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) """ diff --git a/cyclops/evaluate/metrics/roc.py b/cyclops/evaluate/metrics/roc.py index 1b2774c8b..58587856c 100644 --- a/cyclops/evaluate/metrics/roc.py +++ b/cyclops/evaluate/metrics/roc.py @@ -1,10 +1,13 @@ """Classes for computing ROC metrics.""" -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Union import numpy as np import numpy.typing as npt +from cyclops.evaluate.metrics.functional.roc import ( + ROCCurve as ROCCurveData, +) from cyclops.evaluate.metrics.functional.roc import ( _binary_roc_compute, _multiclass_roc_compute, @@ -39,22 +42,22 @@ class BinaryROCCurve(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve" >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = BinaryROCCurve() >>> metric(target, preds) - (array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) + ROCCurve(fpr=array([0. , 0. , 0.5, 0.5, 1. ]), tpr=array([0. , 0.5, 0.5, 1. , 1. ]), thresholds=array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) + ROCCurve(fpr=array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), tpr=array([0. 
, 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), thresholds=array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -63,13 +66,12 @@ def compute( ) else: state = self.confmat # type: ignore[attr-defined] - - return _binary_roc_compute( - state, - thresholds=self.thresholds, - pos_label=self.pos_label, + fpr_, tpr_, thresholds_ = _binary_roc_compute( + state, thresholds=self.thresholds, pos_label=self.pos_label ) + return ROCCurveData(fpr_, tpr_, thresholds_) + class MulticlassROCCurve( MulticlassPrecisionRecallCurve, @@ -101,11 +103,11 @@ class MulticlassROCCurve( >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.9, 0.1, 0]] >>> metric = MulticlassROCCurve(num_classes=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0.33333333, 0.33333333, 1. ], - [0. , 0. , 0. , 1. ]]), array([[0. , 0.5, 0.5, 1. ], + [0. , 0. , 0. , 1. ]]), tpr=array([[0. , 0.5, 0.5, 1. ], [0. , 1. , 1. , 1. ], - [0. , 0. , 1. , 1. ]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0. , 0. , 1. , 1. ]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [ @@ -115,26 +117,19 @@ class MulticlassROCCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0.25, 0.5 , 1. ], + ROCCurve(fpr=array([[0. , 0.25, 0.5 , 1. ], [0. , 0. , 0.25, 1. ], - [0. , 0.25, 0.5 , 1. ]]), array([[0. , 0.25, 0.5 , 1. ], + [0. , 0.25, 0.5 , 1. ]]), tpr=array([[0. , 0.25, 0.5 , 1. ], [0. , 0. , 0.25, 1. ], - [0. , 0. , 0. , 0. ]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0. , 0. , 0. , 0. ]]), thresholds=array([1. 
, 0.66666667, 0.33333333, 0. ])) """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -143,13 +138,12 @@ def compute( ) else: state = self.confmat # type: ignore[attr-defined] - - return _multiclass_roc_compute( - state=state, - num_classes=self.num_classes, - thresholds=self.thresholds, + fpr_, tpr_, thresholds_ = _multiclass_roc_compute( + state, thresholds=self.thresholds, num_classes=self.num_classes ) + return ROCCurveData(fpr_, tpr_, thresholds_) + class MultilabelROCCurve( MultilabelPrecisionRecallCurve, @@ -175,37 +169,30 @@ class MultilabelROCCurve( >>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> metric = MultilabelROCCurve(num_labels=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]] >>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. 
])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) - """ + """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -215,11 +202,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multilabel_roc_compute( + fpr_, tpr_, thresholds_ = _multilabel_roc_compute( state=state, num_labels=self.num_labels, thresholds=self.thresholds, ) + return ROCCurveData(fpr_, tpr_, thresholds_) class ROCCurve(Metric, registry_key="roc_curve", force_register=True): @@ -258,14 +246,14 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = ROCCurve(task="binary", thresholds=None) >>> metric(target, preds) - (array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) + ROCCurve(fpr=array([0. , 0. , 0.5, 0.5, 1. ]), tpr=array([0. , 0.5, 0.5, 1. , 1. ]), thresholds=array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) + ROCCurve(fpr=array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), tpr=array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), thresholds=array([1. 
, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) >>> # (multiclass) >>> from cyclops.evaluate.metrics import ROCCurve @@ -273,22 +261,22 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]] >>> metric = ROCCurve(task="multiclass", num_classes=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0.5, 0.5, 1. ], - [0. , 0. , 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0. , 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 1.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [1, 2] >>> preds = [[[0.05, 0.75, 0.2]], [[0.1, 0.8, 0.1]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0., 0., 0., 1.], + ROCCurve(fpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([[0., 0., 0., 0.], + [0., 0., 0., 1.]]), tpr=array([[0., 0., 0., 0.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 1.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> # (multilabel) >>> from cyclops.evaluate.metrics import ROCCurve @@ -296,22 +284,22 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> metric = ROCCurve(task="multilabel", num_labels=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. 
])) >>> metric.reset_state() >>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]] >>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) """ # noqa: W505 diff --git a/cyclops/evaluate/metrics/utils.py b/cyclops/evaluate/metrics/utils.py index 3bff5b85a..86f7fc9c3 100644 --- a/cyclops/evaluate/metrics/utils.py +++ b/cyclops/evaluate/metrics/utils.py @@ -1,6 +1,16 @@ """Utility functions for metrics.""" -from typing import Any, Callable, List, Literal, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import numpy.typing as npt @@ -295,10 +305,15 @@ def _apply_function_recursively( """ data_type = type(data) - if isinstance(data, (list, tuple, set)): - return data_type( - [_apply_function_recursively(el, func, *args, **kwargs) for el in data], - ) + is_namedtuple_ = ( + isinstance(data, tuple) + and hasattr(data, "_asdict") + and hasattr(data, "_fields") + ) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + if is_namedtuple_ or is_sequence: + out = [_apply_function_recursively(el, func, *args, **kwargs) for el in data] + return data_type(*out) if is_namedtuple_ else data_type(out) if isinstance(data, Mapping): return data_type( { diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index 69aa612d8..255dee72a 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -8,7 
+8,9 @@ import plotly.graph_objs as go from plotly.subplots import make_subplots -from cyclops.evaluate.metrics.experimental.functional import PRCurve, ROCCurve +from cyclops.evaluate.metrics.experimental.functional import PRCurve as PRCurveExp +from cyclops.evaluate.metrics.experimental.functional import ROCCurve as ROCCurveExp +from cyclops.evaluate.metrics.functional import PRCurve, ROCCurve from cyclops.report.plot.base import Plotter from cyclops.report.plot.utils import ( bar_plot, @@ -93,7 +95,7 @@ def _set_class_names(self, class_names: List[str]) -> None: def roc_curve( self, - roc_curve: ROCCurve, + roc_curve: Union[ROCCurve, ROCCurveExp], auroc: Optional[Union[float, List[float], npt.NDArray[np.float_]]] = None, title: Optional[str] = "ROC Curve", layout: Optional[go.Layout] = None, @@ -188,7 +190,7 @@ def roc_curve( def roc_curve_comparison( self, - roc_curves: Dict[str, ROCCurve], + roc_curves: Dict[str, Union[ROCCurve, ROCCurveExp]], aurocs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None, @@ -293,7 +295,7 @@ def roc_curve_comparison( def precision_recall_curve( self, - precision_recall_curve: PRCurve, + precision_recall_curve: Union[PRCurve, PRCurveExp], title: Optional[str] = "Precision-Recall Curve", layout: Optional[go.Layout] = None, **plot_kwargs: Any, @@ -357,7 +359,7 @@ def precision_recall_curve( def precision_recall_curve_comparison( self, - precision_recall_curves: Dict[str, PRCurve], + precision_recall_curves: Dict[str, Union[PRCurve, PRCurveExp]], auprcs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None,