-
Notifications
You must be signed in to change notification settings - Fork 3
/
clf_metrics.py
53 lines (49 loc) · 1.64 KB
/
clf_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Dict
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
class ClfMetric():
    """Binary-classification metrics computed at a fixed decision threshold.

    Thresholded metrics (precision/recall/F1 and the confusion counts) use
    ``probs[:, 1] > threshold``; ranking metrics (ROC-AUC, PR-AUC) use the
    raw positive-class probabilities.
    """

    def __init__(self, threshold: float = 0.5):
        """
        :param threshold: The threshold for classification
        """
        super().__init__()
        self.threshold = threshold

    def eval(self, probs: np.ndarray, labels: np.ndarray) -> Dict:
        """
        Evaluate binary-classification metrics.

        :param probs: array of shape (n_samples, 2); column 1 is the
                      positive-class probability
        :param labels: true binary labels (0/1), shape (n_samples,)
        :return: dict with precision, recall, f1, roc_auc, prc_auc, the
                 confusion-matrix counts (TP/FP/TN/FN) and a composite 'score'
        """
        assert probs.shape[0] == labels.shape[0]
        # Use column 1 consistently as the positive-class probability
        # (the original mixed probs[:, 1] and probs[:, -1] — identical for
        # two-column input, but fragile).
        pos_probs = probs[:, 1]
        preds = (pos_probs > self.threshold).astype(int)
        labels = np.asarray(labels).astype(int)

        # Vectorized confusion-matrix counts (replaces the per-sample loop).
        TP = int(np.sum((preds == 1) & (labels == 1)))
        TN = int(np.sum((preds == 0) & (labels == 0)))
        FP = int(np.sum((preds == 1) & (labels == 0)))
        FN = int(np.sum((preds == 0) & (labels == 1)))

        # Guard zero denominators: e.g. no positive predictions (TP+FP == 0)
        # previously raised ZeroDivisionError; report 0.0 instead.
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        f1 = (2 * TP) / (2 * TP + FP + FN) if (2 * TP + FP + FN) > 0 else 0.0

        roc_auc = roc_auc_score(labels, pos_probs)
        pres, recs, _ = precision_recall_curve(labels, pos_probs)
        prc_auc = auc(recs, pres)

        total = TP + TN + FP + FN
        # Composite score kept for backward compatibility with callers;
        # guarded like precision (0.0 when there are no positive predictions).
        score = recall * recall * total / (TP + FP) if (TP + FP) > 0 else 0.0

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "roc_auc": roc_auc,
            "prc_auc": prc_auc,
            "TP": TP,
            "FP": FP,
            "TN": TN,
            "FN": FN,
            'score': score
        }