Skip to content

Commit

Permalink
Simplify denominator check
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-nml committed Aug 1, 2024
1 parent 1b7933f commit 1368567
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,11 +782,9 @@ def estimate_f1(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Union[pd.Ser
fp = np.where(y_pred == 1, 1 - y_pred_proba, 0)
fn = np.where(y_pred == 0, y_pred_proba, 0)
TP, FP, FN = np.sum(tp), np.sum(fp), np.sum(fn)
if TP + 0.5 * (FP + FN) == 0:
metric = 0
else:
metric = TP / (TP + 0.5 * (FP + FN))
return metric

denominator = TP + 0.5 * (FP + FN)
return TP / denominator if denominator != 0 else 0


@MetricFactory.register('precision', ProblemType.CLASSIFICATION_BINARY)
Expand Down Expand Up @@ -929,11 +927,9 @@ def estimate_precision(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Union
tp = np.where(y_pred == 1, y_pred_proba, 0)
fp = np.where(y_pred == 1, 1 - y_pred_proba, 0)
TP, FP = np.sum(tp), np.sum(fp)
if TP + FP == 0:
metric = 0
else:
metric = TP / (TP + FP)
return metric

denominator = TP + FP
return TP / denominator if denominator != 0 else 0


@MetricFactory.register('recall', ProblemType.CLASSIFICATION_BINARY)
Expand Down Expand Up @@ -1076,11 +1072,9 @@ def estimate_recall(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Union[pd
tp = np.where(y_pred == 1, y_pred_proba, 0)
fn = np.where(y_pred == 0, y_pred_proba, 0)
TP, FN = np.sum(tp), np.sum(fn)
if TP + FN == 0:
metric = 0
else:
metric = TP / (TP + FN)
return metric

denominator = TP + FN
return TP / denominator if denominator != 0 else 0


@MetricFactory.register('specificity', ProblemType.CLASSIFICATION_BINARY)
Expand Down Expand Up @@ -1215,11 +1209,9 @@ def estimate_specificity(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Uni
tn = np.where(y_pred == 0, 1 - y_pred_proba, 0)
fp = np.where(y_pred == 1, 1 - y_pred_proba, 0)
TN, FP = np.sum(tn), np.sum(fp)
if TN + FP == 0:
metric = 0
else:
metric = TN / (TN + FP)
return metric

denominator = TN + FP
return TN / denominator if denominator != 0 else 0


@MetricFactory.register('accuracy', ProblemType.CLASSIFICATION_BINARY)
Expand Down

0 comments on commit 1368567

Please sign in to comment.