forked from CytAI/SRLOOD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
115 lines (86 loc) · 3.55 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import torch
import numpy as np
from utils import set_seed, collate_fn, pad0s, CDataLoader,CDataLoaderDev
from sklearn.metrics import roc_auc_score
def merge_keys(l, keys):
    """Concatenate per-batch score collections into one flat list per key.

    Args:
        l: iterable of mappings (in this module, the per-batch results of
            ``model.compute_ood``). Each mapping must contain every name in
            ``keys``, and each value must be iterable. ``l`` is traversed
            once per key, so pass a re-iterable sequence, not a generator.
        keys: the names to merge.

    Returns:
        dict mapping each name in ``keys`` to a list of every element of
        ``part[name]``, for each ``part`` in ``l``, in order.
    """
    return {
        name: [score for part in l for score in part[name]]
        for name in keys
    }
def _move_batch_to_device(batch, device):
    """Return a new dict with every non-``SRL`` entry of ``batch`` moved to ``device``.

    An entry is passed through untouched when ``'SRL'`` is one of the
    underscore-separated tokens of its key; every other value must support
    ``.to(device)``. Presumably the ``SRL`` entries are structures built by
    ``collate_fn`` that are not plain tensors — TODO confirm.
    """
    return {
        key: value if 'SRL' in key.split('_') else value.to(device)
        for key, value in batch.items()
    }


def _collect_ood_scores(model, dataset, batch_size, keys, device):
    """Run ``model.compute_ood`` over ``dataset`` in order and merge the results by ``keys``."""
    dataloader = CDataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn,
                             shuffle=False, drop_last=False)
    batch_scores = []
    for batch in dataloader:
        batch = _move_batch_to_device(batch, device)
        with torch.no_grad():
            batch_scores.append(model.compute_ood(**batch))
    return merge_keys(batch_scores, keys)


def evaluate_ood(args, model, features, ood, tag, *, device=0):
    """Score in-distribution and OOD data; report AUROC and FPR@95 per score type.

    Gradients are disabled while ``model.compute_ood`` runs over ``features``
    (labelled 1, in-distribution) and ``ood`` (labelled 0, out-of-distribution).
    For each score name in ``softmax``, ``maha``, ``cosine`` and ``energy``,
    the two score sets are pooled and evaluated with ``get_auroc`` and
    ``get_fpr_95``.

    Args:
        args: run configuration; only ``args.batch_size`` is read here.
        model: must provide ``eval()`` and ``compute_ood(**batch)``. The latter
            returns a mapping containing the four score names, each mapped to
            an iterable of per-example scores.
        features: in-distribution dataset, passed to ``CDataLoader``.
        ood: out-of-distribution dataset, passed to ``CDataLoader``.
        tag: prefix for every key of the returned dict.
        device: target for ``.to()`` on each non-``SRL`` batch entry. It
            defaults to ``0``, which is the value previously hard-coded.

    Returns:
        dict with keys ``f"{tag}_{name}_auroc"`` and ``f"{tag}_{name}_fpr95"``
        for each score name.
    """
    keys = ['softmax', 'maha', 'cosine', 'energy']
    # Eval mode only needs to be set once, not on every batch.
    model.eval()
    in_scores = _collect_ood_scores(model, features, args.batch_size, keys, device)
    out_scores = _collect_ood_scores(model, ood, args.batch_size, keys, device)

    outputs = {}
    for key in keys:
        ins = np.array(in_scores[key], dtype=np.float64)
        outs = np.array(out_scores[key], dtype=np.float64)
        # In-distribution examples are the positive class (1), OOD the negative (0).
        inl = np.ones_like(ins, dtype=np.int64)
        outl = np.zeros_like(outs, dtype=np.int64)
        scores = np.concatenate([ins, outs], axis=0)
        labels = np.concatenate([inl, outl], axis=0)
        outputs[tag + "_" + key + "_auroc"] = get_auroc(labels, scores)
        outputs[tag + "_" + key + "_fpr95"] = get_fpr_95(labels, scores)
    return outputs
def get_auroc(key, prediction):
    """Return ``roc_auc_score`` of ``prediction`` against binarised labels.

    Strictly positive entries of ``key`` become 1 and zeros remain 0.
    Negative entries are not remapped; they keep their value.
    ``key`` itself is not modified.
    """
    binary_labels = np.array(key, copy=True)
    binary_labels[key > 0] = 1
    return roc_auc_score(binary_labels, prediction)
def get_fpr_95(key, prediction):
    """Return the FPR at 95% recall of ``prediction`` against binarised labels.

    Strictly positive entries of ``key`` become 1 and zeros remain 0.
    Negative entries are not remapped. The result is computed by
    ``fpr_and_fdr_at_recall`` with its default recall level and positive label.
    ``key`` itself is not modified.
    """
    binary_labels = np.array(key, copy=True)
    binary_labels[key > 0] = 1
    return fpr_and_fdr_at_recall(binary_labels, prediction)
def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Return the float64 cumulative sum of ``arr``, checked against its total.

    The last element of the running sum must be ``np.allclose`` (with
    ``rtol``/``atol``) to the direct float64 sum. Otherwise a RuntimeError is
    raised. An empty ``arr`` raises IndexError, because the check reads the
    last element.
    """
    running = np.cumsum(arr, dtype=np.float64)
    total = np.sum(arr, dtype=np.float64)
    consistent = np.allclose(running[-1], total, rtol=rtol, atol=atol)
    if consistent:
        return running
    raise RuntimeError('cumsum was found to be unstable: '
                       'its last element does not correspond to sum')
def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=1.):
    """Return the false-positive rate at the threshold whose recall is closest to ``recall_level``.

    Despite the name, only the FPR is returned; no FDR is computed.

    Args:
        y_true: array-like of labels. Entries equal to ``pos_label`` are
            positives; all others are negatives.
        y_score: array-like of scores with the same length as ``y_true``.
            Higher means "more positive".
        recall_level: target recall (true-positive rate).
        pos_label: label value treated as the positive class.

    Returns:
        Number of negatives scored at or above the chosen threshold, divided
        by the total number of negatives. The chosen threshold is the first
        one (strictest first) whose recall is closest to ``recall_level``;
        its recall may be below the target.

    Notes:
        - If there are no negatives, the final division is 0/0 and NumPy
          returns nan with a RuntimeWarning.
        - If there are no positives, recall is undefined (nan) and the
          result is not meaningful.
        - Empty input raises IndexError via ``stable_cumsum``.
    """
    # Accept plain Python sequences as well as ndarrays; arrays pass through unchanged.
    y_true = np.asarray(y_true) == pos_label
    y_score = np.asarray(y_score)

    # Sort examples by descending score.
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # Index of the last element of each run of tied scores, plus the final
    # element, so each candidate threshold counts every example tied with it.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # True/false positive counts when thresholding at each distinct score.
    tps = stable_cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps  # +1 turns the zero-based index into a count

    recall = tps / tps[-1]

    # Keep thresholds from the first one reaching full recall back to the
    # strictest one, then append a (recall=1, fps=0) sentinel. The sentinel
    # ties with the full-recall entry at index 0, and argmin returns the first
    # minimum, so the sentinel is never selected.
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)
    recall = np.r_[recall[sl], 1]
    fps = np.r_[fps[sl], 0]

    cutoff = np.argmin(np.abs(recall - recall_level))

    return fps[cutoff] / (np.sum(np.logical_not(y_true)))