Skip to content

Commit

Permalink
Update ignite/contrib/metrics/roc_auc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] committed Jul 7, 2023
1 parent 9519e9d commit b919851
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion ignite/contrib/metrics/roc_auc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Callable, cast, Tuple, Union

import torch
import numpy as np

from ignite import distributed as idist
from ignite.exceptions import NotComputableError
Expand All @@ -20,7 +21,12 @@ def roc_auc_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) ->

y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return roc_curve(y_true, y_pred)
fpr, tpr, thresholds = roc_curve(y_true, y_pred)

# Replace any 'inf' values in the thresholds array with a large finite number
thresholds = np.where(np.isinf(thresholds), np.nanmax(thresholds[np.isfinite(thresholds)]), thresholds)

return fpr, tpr, thresholds


class ROC_AUC(EpochMetric):
Expand Down Expand Up @@ -192,3 +198,4 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i
thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

return fpr, tpr, thresholds

0 comments on commit b919851

Please sign in to comment.