-
Notifications
You must be signed in to change notification settings - Fork 18
/
sampler.py
86 lines (69 loc) · 2.86 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import random
import numpy as np
from multiprocessing import Process, Queue
def random_neq(l, r, s):
    """Draw a random integer from ``[l, r)`` that is not a member of ``s``.

    Uses rejection sampling on ``np.random.randint(l, r)``: candidates are
    drawn from numpy's global generator until one falls outside ``s``.

    Args:
        l: Inclusive lower bound of the draw.
        r: Exclusive upper bound of the draw.
        s: Container of values to avoid (only ``in`` is used on it).

    Returns:
        The first drawn value that is not in ``s``.

    Note:
        The loop never terminates if every integer in ``[l, r)`` is in ``s``.
    """
    while True:
        candidate = np.random.randint(l, r)
        if candidate not in s:
            return candidate
def sample_function(user_train, usernum, itemnum, batch_size, maxlen,
                    threshold_user, threshold_item,
                    result_queue, SEED):
    """Generate training batches forever and push them onto ``result_queue``.

    Meant to run as the target of a worker process: the batching loop never
    returns on its own.

    Args:
        user_train: Mapping from user id to that user's item sequence; it is
            indexed with ids drawn from ``1..usernum``.
        usernum: Largest user id; users are drawn uniformly from 1..usernum.
        itemnum: Largest item id; random items are drawn from 1..itemnum.
        batch_size: Number of samples collected per batch.
        maxlen: Length of each ``seq`` / ``pos`` / ``neg`` array.
        threshold_user: The returned user id is swapped for a random one
            whenever a uniform draw in [0, 1) exceeds this value, so 1.0
            disables the swap.
        threshold_item: Same rule, applied per sequence position to the
            input item and its target.
        result_queue: Object with a ``put`` method. Each batch is delivered
            as ``zip(*samples)``, i.e. an iterator over four columns
            (users, seqs, positives, negatives).
        SEED: Seed for this worker. It seeds numpy's global generator and,
            when int-like, the generator used for the replacement coin flips.
    """
    np.random.seed(SEED)

    # The replacement coin flips use the stdlib generator, which
    # np.random.seed() does not control. Drive them from a private,
    # SEED-derived generator so a fixed SEED reproduces the same batches
    # even when a threshold is below 1.0. The numpy stream is untouched.
    try:
        coin = random.Random(int(SEED))
    except (TypeError, ValueError):
        # SEED is None or not int-like: keep an entropy-seeded generator.
        coin = random.Random()

    def sample():
        """Build one ``(user, seq, pos, neg)`` training sample."""
        # Redraw until the user has at least two interactions.
        user = np.random.randint(1, usernum + 1)
        while len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum + 1)

        history = user_train[user]
        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)
        nxt = history[-1]
        idx = maxlen - 1

        ts = set(history)
        # Walk the history backwards, filling the arrays from the right;
        # positions that are never reached keep their initial 0.
        for item in reversed(history[:-1]):
            # Item-side SSE: with probability (1 - threshold_item) replace
            # both the input item and its target with random items.
            if coin.random() > threshold_item:
                item = np.random.randint(1, itemnum + 1)
                nxt = np.random.randint(1, itemnum + 1)
            seq[idx] = item
            pos[idx] = nxt
            if nxt != 0:
                # Negative example: a random item outside the user's history.
                neg[idx] = random_neq(1, itemnum + 1, ts)
            # The (possibly replaced) input is the target one step earlier.
            nxt = item
            idx -= 1
            if idx == -1:
                break

        # User-side SSE: with probability (1 - threshold_user) report a
        # random user id instead; the sequence itself is left as built.
        # (Always returning user 1 here would be equivalent to hard
        # parameter sharing.)
        if coin.random() > threshold_user:
            user = np.random.randint(1, usernum + 1)

        return (user, seq, pos, neg)

    while True:
        one_batch = [sample() for _ in range(batch_size)]
        result_queue.put(zip(*one_batch))
class WarpSampler(object):
    """Background batch producer backed by ``sample_function`` workers.

    Each worker is a daemon process that pushes batches onto a shared
    bounded queue; ``next_batch`` pops one batch from that queue.
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10,
                 threshold_user=1.0, threshold_item=1.0, n_workers=1):
        """Create the queue and start ``n_workers`` sampling processes.

        Args:
            User: Per-user training data handed to ``sample_function``.
            usernum: Passed through to ``sample_function``.
            itemnum: Passed through to ``sample_function``.
            batch_size: Samples per batch.
            maxlen: Sequence length per sample.
            threshold_user: Passed through to ``sample_function``.
            threshold_item: Passed through to ``sample_function``.
            n_workers: Number of worker processes; the queue holds at most
                ``n_workers * 10`` batches.
        """
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []

        # Everything except the seed is identical for all workers.
        shared_args = (User, usernum, itemnum, batch_size, maxlen,
                       threshold_user, threshold_item, self.result_queue)

        for _ in range(n_workers):
            # A fresh seed per worker, drawn from numpy's global generator.
            worker_seed = np.random.randint(2e9)
            worker = Process(target=sample_function,
                             args=shared_args + (worker_seed,))
            self.processors.append(worker)
            worker.daemon = True
            worker.start()

    def next_batch(self):
        """Return the next batch from the queue, blocking until one exists."""
        return self.result_queue.get()

    def close(self):
        """Terminate every worker process and wait for it to exit."""
        for worker in self.processors:
            worker.terminate()
            worker.join()