"""Helpers for collecting matching leaves from nested containers and for
batching samples from an indexable dataset with NumPy."""
import numpy as np
def parse_tensors(objs, target_list, predicate):
    """Append every node of a nested structure accepted by ``predicate``.

    The structure is walked depth-first, in iteration order. At each node
    ``predicate`` is tried first: if it accepts the node, the node itself is
    appended to ``target_list`` and its contents are NOT visited. Otherwise,
    lists and tuples are descended element by element and dicts are
    descended through their values (keys are never inspected). Any other
    object that the predicate rejects is skipped.

    Args:
        objs: The object or nested list/tuple/dict to scan.
        target_list: List that matching nodes are appended to (mutated in place).
        predicate: Callable taking one object and returning a truthy value
            for nodes that should be collected.

    Returns:
        None; results are delivered through ``target_list``.
    """
    if predicate(objs):
        target_list.append(objs)
        return

    if isinstance(objs, (list, tuple)):
        children = objs
    elif isinstance(objs, dict):
        children = objs.values()
    else:
        # Leaf that the predicate rejected: nothing to collect.
        return

    for child in children:
        parse_tensors(child, target_list, predicate)
def batch_wrapper(iterator, batch_size, transform, shuffle=True):
    """Yield mini-batches from an indexable dataset, forever.

    Each item ``iterator[i]`` is converted to a list of fields. When
    ``transform`` is truthy it is applied to field 0 only. Once
    ``batch_size`` samples have been gathered, the generator yields a list
    with one NumPy array per field. Each array holds that field from every
    sample, stacked along a new leading axis.

    The dataset is traversed epoch after epoch without end. When ``shuffle``
    is true, each epoch's order is drawn with ``np.random.shuffle``, which
    uses NumPy's global RNG. Partially filled batches carry over into the
    next epoch, so a batch can mix samples from two consecutive epochs, and
    no sample is ever skipped.

    Args:
        iterator: Object supporting ``len()`` and integer indexing whose
            items are sequences of per-sample fields.
        batch_size: Number of samples per yielded batch; must be >= 1.
        transform: Optional callable applied to the first field of each
            sample; falsy values disable it.
        shuffle: Whether to shuffle the sample order at every epoch.

    Yields:
        list of np.ndarray, one per field, each with leading dimension
        ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` < 1, or if ``iterator`` is empty at the
            start of an epoch (otherwise the loop would spin forever without
            yielding). Because this is a generator, the error surfaces on the
            first ``next()`` call.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    batch = []
    while True:
        order = list(range(len(iterator)))
        if not order:
            raise ValueError("iterator must contain at least one sample")
        if shuffle:
            np.random.shuffle(order)
        for idx in order:
            sample = list(iterator[idx])
            if transform:
                sample[0] = transform(sample[0])
            # Append before checking fullness, so the sample that completes a
            # batch is kept instead of being dropped.
            batch.append(sample)
            if len(batch) == batch_size:
                num_fields = len(batch[0])
                yield [
                    np.stack([s[j] for s in batch], axis=0)
                    for j in range(num_fields)
                ]
                batch = []