# utils.py
import numpy as np
import threading
import queue
import multiprocessing
from collections import defaultdict
import jax.numpy as jnp

def make_batch(samples):
    "Convert a dict of numpy arrays into a dict of JAX arrays for causal LM training"
    batch = {k: jnp.array(v) for k, v in samples.items()}
    # for causal language modeling the labels are the inputs themselves;
    # the loss function is expected to handle the one-token shift
    batch["labels"] = batch["input_ids"].copy()
    return batch
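
# A minimal usage sketch for make_batch. The feature names below are
# illustrative, not required by the function; any dict of equally-shaped
# arrays works:
#
#   samples = {"input_ids": np.zeros((4, 128), dtype=np.int32),
#              "attention_mask": np.ones((4, 128), dtype=np.int32)}
#   batch = make_batch(samples)
#   assert (batch["labels"] == batch["input_ids"]).all()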

class PrefetchDataloaderThread(threading.Thread):
    "Thread-based prefetch dataloader for an IterableDataset"
    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        if shuffle:
            shuffled_dataset = dataset.shuffle(shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(dataset)
        self.queue = queue.Queue(prefetch_buffer)
        self.rem = defaultdict(list)  # leftover tokens carried over between batches
        self.start()

    def __next__(self):
        batch = self.queue.get()
        if batch is None:
            # None is the sentinel the producer puts after max_steps batches
            raise StopIteration
        return batch

    def run(self):
        i = 0
        while i < self.max_steps:
            i += 1
            # prepare next batch: start from the leftover tokens of the previous one
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            # concatenate samples until there are enough tokens for a full batch
            while l < max_length:
                next_sample = next(self.ds_iter)
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            # keep the overflow for the next batch
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs, seq_len]
            samples = {k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)]) for k, v in sample.items()}
            self.queue.put(make_batch(samples))
        self.queue.put(None)

    def __iter__(self):
        return self
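
# Usage sketch, under assumptions: a Hugging Face `datasets` streaming
# dataset whose examples carry token lists under "input_ids", e.g. after a
# tokenizer map. The dataset name, columns, `tokenizer`, and `train_step`
# below are illustrative, not part of this module:
#
#   from datasets import load_dataset
#   ds = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
#   ds = ds.map(lambda x: tokenizer(x["text"]), remove_columns=["id", "text"])
#   dl = PrefetchDataloaderThread(ds, max_steps=100, batch_size=8, sequence_length=512)
#   for batch in dl:  # iteration stops after max_steps batches (None sentinel)
#       train_step(batch)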

class PrefetchDataloader(multiprocessing.Process):
    "Process-based prefetch dataloader for an IterableDataset"
    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = multiprocessing.Queue(prefetch_buffer)
        self.rem = defaultdict(list)  # leftover tokens carried over between batches
        self.start()

    def make_iter(self):
        if self.shuffle:
            # reshuffle with a fresh seed on every pass through the dataset
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        samples = self.queue.get()
        if samples is None:
            # None is the sentinel the producer puts after max_steps batches
            raise StopIteration
        # unlike the thread version, the JAX arrays are created on the consumer
        # side: only plain numpy arrays cross the process boundary
        return make_batch(samples)

    def run(self):
        i = 0
        while i < self.max_steps:
            i += 1
            # prepare next batch: start from the leftover tokens of the previous one
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            # concatenate samples until there are enough tokens for a full batch
            while l < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # reset the iterator once a full pass through the dataset completes
                    self.make_iter()
                    continue
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            # keep the overflow for the next batch
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs, seq_len]
            samples = {k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)]) for k, v in sample.items()}
            self.queue.put(samples)
        self.queue.put(None)

    def __iter__(self):
        return self
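
# The process-based loader keeps batching off the main process entirely, so
# it is not throttled by the GIL. A hedged sketch mirroring the thread
# version above, with `ds` and `train_step` as before:
#
#   dl = PrefetchDataloader(ds, max_steps=100, batch_size=8, sequence_length=512)
#   for batch in dl:
#       train_step(batch)
#   dl.join()  # the producer exits after putting the None sentinel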