-
Notifications
You must be signed in to change notification settings - Fork 3
/
sampler.py
93 lines (75 loc) · 3.02 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
from multiprocessing import Process, Queue
import matplotlib as plt
def random_neq(l, r, s):
    """Draw a random integer from ``[l, r)`` that is not contained in ``s``.

    Args:
        l: Inclusive lower bound handed to ``np.random.randint``.
        r: Exclusive upper bound handed to ``np.random.randint``.
        s: Container of excluded values; only membership (``in``) is used.

    Returns:
        The first drawn value that is not a member of ``s``.

    Note:
        This is plain rejection sampling, so the call never returns when
        every integer in ``[l, r)`` is a member of ``s``.
    """
    while True:
        candidate = np.random.randint(l, r)
        if candidate not in s:
            return candidate
def sample_function(user_train, Beh, Beh_w, usernum, itemnum, batch_size, maxlen,
                    result_queue, SEED, recency_alpha=0.5):
    """Generate training batches forever and push them onto ``result_queue``.

    Intended as the target of a worker process: it seeds NumPy's global RNG
    with ``SEED`` and then loops without end, so it only stops when the
    process is terminated or ``result_queue.put`` raises.

    Args:
        user_train: Mapping ``user -> sequence of items``, indexed by the
            integer user ids ``1..usernum``.
        Beh: Mapping keyed by ``(user, item)`` giving the context attached to
            that pair. It is also looked up for the zero entries left in
            ``seq``/``pos``, so ``(user, 0)`` must be present whenever a
            user's history does not fill ``maxlen`` positions.
        Beh_w: Mapping keyed by ``(user, item)`` giving the weight of the
            positive item; same key requirement as ``Beh``.
        usernum: Largest user id; users are drawn from ``[1, usernum]``.
        itemnum: Largest item id; negatives are drawn from ``[1, itemnum]``.
        batch_size: Number of samples per batch.
        maxlen: Fixed length of every per-sample array.
        result_queue: Object with a ``put`` method that receives each batch.
        SEED: Seed passed to ``np.random.seed`` before sampling starts.
        recency_alpha: Base of the recency weight; the position ``k`` steps
            from the end of the sequence (``k = 1`` is the last slot) gets
            ``recency_alpha ** k``. Defaults to 0.5, the previously
            hard-coded value.

    Each batch is ``zip(*samples)``, i.e. nine columns in the order
    ``(user, seq, pos, neg, seq_cxt, pos_cxt, pos_weight, neg_weight,
    recency)`` with one entry per sample in every column.
    """
    def sample():
        """Build one right-aligned, zero-padded training sample."""
        # Re-draw until the user has at least two interactions, since one
        # interaction is needed as input and one as the target.
        user = np.random.randint(1, usernum + 1)
        while len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum + 1)

        history = user_train[user]
        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)
        recency = np.zeros([maxlen], dtype=np.float32)

        nxt = history[-1]
        idx = maxlen - 1
        seen = set(history)

        # Walk the history backwards, filling the arrays from the right so
        # that unused leading positions keep their zero value.
        for item in reversed(history[:-1]):
            seq[idx] = item
            pos[idx] = nxt
            recency[idx] = recency_alpha ** (maxlen - idx)
            # A zero target gets no negative sample.
            if nxt != 0:
                neg[idx] = random_neq(1, itemnum + 1, seen)
            nxt = item
            idx -= 1
            if idx == -1:
                break

        seq_cxt = np.asarray([Beh[(user, item)] for item in seq])
        pos_cxt = np.asarray([Beh[(user, item)] for item in pos])
        pos_weight = np.asarray([Beh_w[(user, item)] for item in pos])
        # NOTE(review): one unit weight per position, kept as a plain list
        # (not an ndarray) to match the existing batch layout. This assumes
        # the original appended inside the loop over ``pos`` - confirm.
        neg_weight = [1.0] * len(pos)

        return (user, seq, pos, neg, seq_cxt, pos_cxt, pos_weight, neg_weight, recency)

    np.random.seed(SEED)
    while True:
        one_batch = [sample() for _ in range(batch_size)]
        # Transpose the list of samples into per-field columns.
        result_queue.put(zip(*one_batch))
class WarpSampler(object):
    """Produce training batches in background processes via a shared queue.

    Every worker runs ``sample_function`` and pushes finished batches onto
    ``result_queue``; ``next_batch`` pops one of them.
    """

    def __init__(self, User, Beh, Beh_w, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        """Create the queue and start ``n_workers`` daemon worker processes.

        Args:
            User: Per-user item sequences, forwarded to ``sample_function``.
            Beh: ``(user, item)`` context mapping, forwarded unchanged.
            Beh_w: ``(user, item)`` weight mapping, forwarded unchanged.
            usernum: Largest user id, forwarded unchanged.
            itemnum: Largest item id, forwarded unchanged.
            batch_size: Number of samples in each batch.
            maxlen: Length of every per-sample array.
            n_workers: Number of worker processes; the queue holds at most
                ``n_workers * 10`` pending batches.
        """
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            # Each worker receives its own seed drawn from the parent's RNG.
            worker_args = (
                User,
                Beh,
                Beh_w,
                usernum,
                itemnum,
                batch_size,
                maxlen,
                self.result_queue,
                np.random.randint(2e9),
            )
            worker = Process(target=sample_function, args=worker_args)
            self.processors.append(worker)
            # Daemon processes are killed automatically when the parent exits.
            worker.daemon = True
            worker.start()

    def next_batch(self):
        """Return the next batch from the queue, blocking until one is ready."""
        return self.result_queue.get()

    def close(self):
        """Terminate every worker process and wait for it to exit."""
        for worker in self.processors:
            worker.terminate()
            worker.join()