Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use namedtuple to store curve results (ROC, PR) for non-experimental … #574

Merged
merged 3 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cyclops/evaluate/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@
recall,
)
from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # noqa: F401
PRCurve,
binary_precision_recall_curve,
multiclass_precision_recall_curve,
multilabel_precision_recall_curve,
precision_recall_curve,
)
from cyclops.evaluate.metrics.functional.roc import ( # noqa: F401
ROCCurve,
binary_roc_curve,
multiclass_roc_curve,
multilabel_roc_curve,
Expand Down
98 changes: 42 additions & 56 deletions cyclops/evaluate/metrics/functional/precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for computing the precision-recall curve for different input types."""

from typing import Any, List, Literal, Optional, Tuple, Union
from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -15,6 +15,14 @@
)


class PRCurve(NamedTuple):
"""Named tuple with Precision-Recall curve (Precision, Recall and thresholds)."""

precision: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]
recall: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]
thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]


def _format_thresholds(
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
) -> Optional[npt.NDArray[np.float_]]:
Expand Down Expand Up @@ -279,7 +287,7 @@ def binary_precision_recall_curve(
preds: npt.ArrayLike,
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
pos_label: int = 1,
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
) -> PRCurve:
"""Compute precision-recall curve for binary input.

Parameters
Expand All @@ -301,13 +309,10 @@ def binary_precision_recall_curve(

Returns
-------
precision : numpy.ndarray
Precision scores such that element i is the precision of predictions
with score >= thresholds[i].
recall : numpy.ndarray
Recall scores in descending order.
thresholds : numpy.ndarray
Thresholds used for computing the precision and recall scores.
PRCurve
A named tuple containing the precision (element i is the precision of predictions
with score >= thresholds[i]), recall (scores in descending order)
and thresholds used to compute the precision-recall curve.

Examples
--------
Expand Down Expand Up @@ -335,13 +340,14 @@ def binary_precision_recall_curve(
thresholds = _format_thresholds(thresholds)

state = _binary_precision_recall_curve_update(target, preds, thresholds)

return _binary_precision_recall_curve_compute(
precision_, recall_, thresholds_ = _binary_precision_recall_curve_compute(
state,
thresholds,
pos_label=pos_label,
)

return PRCurve(precision_, recall_, thresholds_)


def _multiclass_precision_recall_curve_format(
target: npt.ArrayLike,
Expand Down Expand Up @@ -572,14 +578,7 @@ def multiclass_precision_recall_curve(
preds: npt.ArrayLike,
num_classes: int,
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
) -> Union[
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
Tuple[
List[npt.NDArray[np.float_]],
List[npt.NDArray[np.float_]],
List[npt.NDArray[np.float_]],
],
]:
) -> PRCurve:
"""Compute the precision-recall curve for multiclass problems.

Parameters
Expand All @@ -600,18 +599,13 @@ def multiclass_precision_recall_curve(

Returns
-------
precision : numpy.ndarray or list of numpy.ndarray
Precision scores where element i is the precision score corresponding
to the threshold i. If state is a tuple of the target and predicted
probabilities, then precision is a list of arrays, where each array
corresponds to the precision scores for a class.
recall : numpy.ndarray or list of numpy.ndarray
Recall scores where element i is the recall score corresponding to
the threshold i. If state is a tuple of the target and predicted
probabilities, then recall is a list of arrays, where each array
corresponds to the recall scores for a class.
thresholds : numpy.ndarray or list of numpy.ndarray
Thresholds used for computing the precision and recall scores.
PRcurve
A named tuple containing the precision, recall, and thresholds.
Precision and recall are arrays where element i is the precision and
recall score corresponding to threshold i. If state is a tuple of the
target and predicted probabilities, then precision and recall are lists
of arrays, where each array corresponds to the precision and recall
scores for a class.

Examples
--------
Expand Down Expand Up @@ -652,11 +646,12 @@ def multiclass_precision_recall_curve(
thresholds=thresholds,
)

return _multiclass_precision_recall_curve_compute(
precision_, recall_, thresholds_ = _multiclass_precision_recall_curve_compute(
state,
thresholds, # type: ignore
num_classes,
)
return PRCurve(precision_, recall_, thresholds_)


def _multilabel_precision_recall_curve_format(
Expand Down Expand Up @@ -868,14 +863,7 @@ def multilabel_precision_recall_curve(
preds: npt.ArrayLike,
num_labels: int,
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
) -> Union[
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
Tuple[
List[npt.NDArray[np.float_]],
List[npt.NDArray[np.float_]],
List[npt.NDArray[np.float_]],
],
]:
) -> PRCurve:
"""Compute the precision-recall curve for multilabel input.

Parameters
Expand All @@ -897,16 +885,18 @@ def multilabel_precision_recall_curve(

Returns
-------
precision : numpy.ndarray or List[numpy.ndarray]
PRCurve
A named tuple with the following:
- ``precision``: numpy.ndarray or List[numpy.ndarray].
Precision values for each label. If ``thresholds`` is None, then
precision is a list of arrays, one for each label. Otherwise,
precision is a single array with shape
(``num_labels``, len(``thresholds``)).
recall : numpy.ndarray or List[numpy.ndarray]
- ``recall``: numpy.ndarray or List[numpy.ndarray].
Recall values for each label. If ``thresholds`` is None, then
recall is a list of arrays, one for each label. Otherwise,
recall is a single array with shape (``num_labels``, len(``thresholds``)).
thresholds : numpy.ndarray or List[numpy.ndarray]
- ``thresholds``: numpy.ndarray or List[numpy.ndarray].
If ``thresholds`` is None, then thresholds is a list of arrays, one for
each label. Otherwise, thresholds is a single array with shape
(len(``thresholds``,).
Expand Down Expand Up @@ -950,11 +940,12 @@ def multilabel_precision_recall_curve(
thresholds=thresholds,
)

return _multilabel_precision_recall_curve_compute(
precision_, recall_, thresholds_ = _multilabel_precision_recall_curve_compute(
state,
thresholds, # type: ignore
num_labels,
)
return PRCurve(precision_, recall_, thresholds_)


def precision_recall_curve(
Expand All @@ -965,14 +956,7 @@ def precision_recall_curve(
pos_label: int = 1,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
) -> Union[
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
Tuple[
List[npt.NDArray[np.float_]],
List[npt.NDArray[np.float_]],
List[npt.NDArray[np.float_]],
],
]:
) -> PRCurve:
"""Compute the precision-recall curve for different tasks/input types.

Parameters
Expand All @@ -997,17 +981,19 @@ def precision_recall_curve(

Returns
-------
precision : numpy.ndarray
PRCurve
A named tuple with the following:
- ``precision``: numpy.ndarray or List[numpy.ndarray].
The precision scores where ``precision[i]`` is the precision score for
``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel',
``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel',
then ``precision`` is a list of numpy arrays, where ``precision[i]`` is the
precision scores for class or label ``i``.
recall : numpy.ndarray
- ``recall``: numpy.ndarray or List[numpy.ndarray].
The recall scores where ``recall[i]`` is the recall score for ``scores >=
thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', then
``recall`` is a list of numpy arrays, where ``recall[i]`` is the recall
scores for class or label ``i``.
thresholds : numpy.ndarray
- ``thresholds``: numpy.ndarray or List[numpy.ndarray].
Thresholds used for computing the precision and recall scores.

Raises
Expand Down
Loading