-
Notifications
You must be signed in to change notification settings - Fork 1
/
sampler.py
executable file
·73 lines (60 loc) · 2.74 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import random
from multiprocessing import Process, Queue
class Sampler(object):
    """Generate (user, positive item, negative item) training triples in background processes.

    ``n_workers`` daemon processes each run :meth:`sample_function` in an endless loop,
    pushing ``int32`` arrays of shape ``(batch_size, 3)`` onto a shared bounded queue;
    :meth:`next_batch` pops one such array, :meth:`close` stops the workers.
    """

    def sample_function(self, user_train, usernum, itemnum, batch_size, result_queue, SEED, negatives=None,
                        hard_rate=1.0):
        """Produce batches of ``[user, pos, neg]`` triples forever (worker entry point).

        Does not use ``self``.

        Args:
            user_train: indexable by every user id in ``range(usernum)``; each entry is the
                collection of that user's items (list, set, array, ...). Users whose
                collection is empty are never emitted.
            usernum: users are drawn uniformly from ``range(usernum)``.
            itemnum: uniform negatives are drawn from ``range(itemnum)``, rejecting items
                present in the user's collection.
            batch_size: number of triples per batch.
            result_queue: object with a ``put`` method that receives each batch.
            SEED: seed applied to both ``numpy.random`` and ``random`` in this process.
            negatives: optional, indexable by user id; per-user pool of candidate
                ("hard") negatives. When None, every negative is drawn uniformly.
            hard_rate: probability of taking the negative from ``negatives[user]`` (when
                that pool is non-empty) instead of drawing it uniformly.

        Never returns: loops until the process is terminated or ``put`` raises.
        """
        from collections.abc import Sequence

        def _pick_one(population):
            # Uniformly pick one element. random.sample() only accepts sequences (sets
            # raise TypeError since Python 3.11, numpy arrays always did), so convert
            # anything else to a tuple first -- what Python <= 3.10 did for sets.
            if not isinstance(population, Sequence):
                population = tuple(population)
            return random.sample(population, 1)[0]

        def _draw_user():
            # Rejection-sample a user that has at least one interaction.
            # NOTE(review): spins forever if every user's collection is empty.
            user = np.random.randint(usernum)
            while len(user_train[user]) == 0:
                user = np.random.randint(usernum)
            return user

        def _draw_uniform_negative(history):
            # Redraw until the item is not in the user's history.
            # NOTE(review): spins forever if the user has interacted with every item.
            while True:
                neg = np.random.randint(itemnum)
                if neg not in history:
                    return neg

        def _sample():
            # One triple with a uniformly drawn negative.
            user = _draw_user()
            history = user_train[user]
            pos = _pick_one(history)
            neg = _draw_uniform_negative(history)
            return np.asarray([user, pos, neg], np.int32)

        def _sample_with_negatives():
            # One triple whose negative comes from the user's hard pool with
            # probability hard_rate, otherwise uniformly.
            user = _draw_user()
            history = user_train[user]
            pos = _pick_one(history)
            hard_pool = negatives[user]
            # Short-circuit: random.random() is only consumed for a non-empty pool.
            if len(hard_pool) == 0 or random.random() > hard_rate:
                neg = _draw_uniform_negative(history)
            else:
                neg = _pick_one(hard_pool)
            return np.asarray([user, pos, neg], np.int32)

        np.random.seed(SEED)
        random.seed(SEED)
        sample = _sample if negatives is None else _sample_with_negatives
        while True:
            one_batch = np.zeros([batch_size, 3], dtype=np.int32)
            for i in range(batch_size):
                one_batch[i, :] = sample()
            result_queue.put(one_batch)

    def __init__(self, User, usernum, itemnum, batch_size=10000, n_workers=1, negatives=None, hard_rate=1.0):
        """Start ``n_workers`` daemon processes running :meth:`sample_function`.

        Args:
            User: per-user item collections, forwarded to each worker as ``user_train``.
            usernum, itemnum, batch_size, negatives, hard_rate: forwarded unchanged to
                :meth:`sample_function`.
            n_workers: number of sampling processes to start.
        """
        # Bounded queue: a worker's put() blocks once n_workers * 2 batches are waiting.
        self.result_queue = Queue(maxsize=n_workers * 2)
        self.processors = []
        for _ in range(n_workers):
            # Each worker gets its own seed drawn from the parent's numpy RNG.
            worker = Process(target=self.sample_function,
                             args=(User,
                                   usernum,
                                   itemnum,
                                   batch_size,
                                   self.result_queue,
                                   np.random.randint(int(2e9)),
                                   negatives,
                                   hard_rate))
            worker.daemon = True
            self.processors.append(worker)
            worker.start()

    def next_batch(self):
        """Block until a worker has queued a batch, then return it."""
        return self.result_queue.get()

    def close(self):
        """Terminate and join every worker process."""
        for p in self.processors:
            p.terminate()
            p.join()