metrics.py
"""Parallel evaluation metrics (ROC AUC and precision@k) for implicit-feedback
recommenders.

``interactions`` is expected to be a CSR-style sparse matrix of shape
(n_users, n_items) exposing ``indptr`` and ``indices`` (e.g.
``scipy.sparse.csr_matrix``), and ``model`` must expose a
``predict(users, items)`` method returning a tensor of scores.
"""
import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from torch import multiprocessing as mp


def get_row_indices(row, interactions):
    """Return the item indices that user ``row`` interacted with."""
    start = interactions.indptr[row]
    end = interactions.indptr[row + 1]
    return interactions.indices[start:end]


def auc(model, interactions, num_workers=1):
    """Compute the mean per-user ROC AUC, spreading the work across
    ``num_workers`` processes."""
    aucs = []
    processes = []
    n_users = interactions.shape[0]
    mp_batch = int(np.ceil(n_users / num_workers))

    queue = mp.Queue()
    rows = np.arange(n_users)
    np.random.shuffle(rows)
    for rank in range(num_workers):
        # Each worker scores a contiguous slice of the shuffled user ids.
        start = rank * mp_batch
        end = min(start + mp_batch, n_users)
        p = mp.Process(target=batch_auc,
                       args=(queue, rows[start:end], interactions, model))
        p.start()
        processes.append(p)

    # Drain the queue until every worker has exited and no results remain.
    while True:
        is_alive = any(p.is_alive() for p in processes)
        if not is_alive and queue.empty():
            break
        while not queue.empty():
            aucs.append(queue.get())

    queue.close()
    for p in processes:
        p.join()
    return np.mean(aucs)


def batch_auc(queue, rows, interactions, model):
    """Worker: put one ROC AUC score per user in ``rows`` onto ``queue``."""
    n_items = interactions.shape[1]
    items = torch.arange(0, n_items).long()
    users_init = torch.ones(n_items).long()
    for row in rows:
        row = int(row)
        # Score every item for this user.
        users = users_init.fill_(row)
        preds = model.predict(users, items)
        actuals = get_row_indices(row, interactions)

        # Skip users with no observed interactions (AUC is undefined).
        if len(actuals) == 0:
            continue

        y_test = np.zeros(n_items)
        y_test[actuals] = 1
        queue.put(roc_auc_score(y_test, preds.data.numpy()))


def patk(model, interactions, num_workers=1, k=5):
    """Compute the mean precision@k, spreading the work across
    ``num_workers`` processes."""
    patks = []
    processes = []
    n_users = interactions.shape[0]
    mp_batch = int(np.ceil(n_users / num_workers))

    queue = mp.Queue()
    rows = np.arange(n_users)
    np.random.shuffle(rows)
    for rank in range(num_workers):
        # Each worker scores a contiguous slice of the shuffled user ids.
        start = rank * mp_batch
        end = min(start + mp_batch, n_users)
        p = mp.Process(target=batch_patk,
                       args=(queue, rows[start:end], interactions, model),
                       kwargs={'k': k})
        p.start()
        processes.append(p)

    # Drain the queue until every worker has exited and no results remain.
    while True:
        is_alive = any(p.is_alive() for p in processes)
        if not is_alive and queue.empty():
            break
        while not queue.empty():
            patks.append(queue.get())

    queue.close()
    for p in processes:
        p.join()
    return np.mean(patks)


def batch_patk(queue, rows, interactions, model, k=5):
    """Worker: put one precision@k score per user in ``rows`` onto ``queue``."""
    n_items = interactions.shape[1]
    items = torch.arange(0, n_items).long()
    users_init = torch.ones(n_items).long()
    for row in rows:
        row = int(row)
        # Score every item for this user.
        users = users_init.fill_(row)
        preds = model.predict(users, items)
        actuals = get_row_indices(row, interactions)

        # Skip users with no observed interactions.
        if len(actuals) == 0:
            continue

        # argpartition on the negated scores puts the k highest-scoring
        # items in the first k positions (in arbitrary order).
        top_k = np.argpartition(-np.squeeze(preds.data.numpy()), k)
        top_k = set(top_k[:k])
        true_pids = set(actuals)
        if true_pids:
            queue.put(len(top_k & true_pids) / float(k))
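

# --- Usage sketch (not part of the original module) -------------------------
# A minimal, self-contained example of how auc() and patk() are meant to be
# called. ``_RandomScorer`` is a hypothetical stand-in for a trained model;
# the only contract assumed is a ``predict(users, items)`` method returning a
# score tensor, which is what batch_auc/batch_patk rely on above.
class _RandomScorer(object):
    """Hypothetical model returning random scores for (user, item) pairs."""

    def predict(self, users, items):
        return torch.rand(len(items))


if __name__ == '__main__':
    import scipy.sparse as sp

    n_users, n_items = 50, 200
    # Random binary implicit-feedback matrix in CSR format, ~5% density.
    dense = (np.random.rand(n_users, n_items) < 0.05).astype(np.float32)
    interactions = sp.csr_matrix(dense)

    model = _RandomScorer()
    print('AUC: %.3f' % auc(model, interactions, num_workers=2))
    print('P@5: %.3f' % patk(model, interactions, num_workers=2, k=5))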