data.py
import torch
from preprocess_utils.utils import *
from tqdm import tqdm


def fix_parent(p, start_i):
    # Re-base a parent index so it is relative to the current chunk;
    # clamp to 0 if the parent falls before the chunk start.
    p -= start_i
    return 0 if p < 0 else p


def data_gen(data, split_size):
    # Split each (N, T, P) sample into chunks of at most `split_size` tokens,
    # re-basing the parent pointers so they stay inside the current chunk.
    for sample in data:
        accum_n = []
        accum_t = []
        accum_p = []
        start_i = 0
        for i, item in enumerate(zip(*sample)):
            n, t, p = item
            p = fix_parent(p, start_i)
            accum_n.append(n)
            accum_t.append(t)
            accum_p.append(p)
            if len(accum_n) == split_size:
                yield accum_n, accum_t, accum_p
                accum_n = []
                accum_t = []
                accum_p = []
                start_i = i
        if len(accum_n) > 0:
            yield accum_n, accum_t, accum_p
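
# A hypothetical example of how data_gen chunks one sample (the ids below are
# made up for illustration): with split_size=3, a 5-token sample is split into
# a chunk of 3 and a chunk of 2, and the parent pointers of the second chunk
# are re-based with fix_parent:
#
#   sample = ([10, 11, 12, 13, 14],   # non-terminal ids
#             [20, 21, 22, 23, 24],   # terminal ids
#             [0, 0, 1, 2, 3])        # parent indices
#   list(data_gen([sample], split_size=3))
#   # -> [([10, 11, 12], [20, 21, 22], [0, 0, 1]),
#   #     ([13, 14], [23, 24], [0, 1])]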

class MainDataset(torch.utils.data.Dataset):
    def __init__(self,
                 N_filename='./pickle_data/PY_non_terminal_small.pickle',
                 T_filename='./pickle_data/PY_terminal_10k_whole.pickle',
                 is_train=False,
                 truncate_size=150):
        super().__init__()
        # input_data (pulled in via the preprocess_utils.utils star import)
        # loads the pickled splits plus vocabulary and attention-window sizes.
        (train_dataN, test_dataN, vocab_sizeN,
         train_dataT, test_dataT, vocab_sizeT,
         attn_size,
         train_dataP, test_dataP) = input_data(N_filename, T_filename)
        self.is_train = is_train
        # Chunk every program into pieces of at most `truncate_size` tokens.
        if self.is_train:
            self.data = list(data_gen(zip(tqdm(train_dataN), train_dataT, train_dataP), truncate_size))
        else:
            self.data = list(data_gen(zip(tqdm(test_dataN), test_dataT, test_dataP), truncate_size))
        # Sort chunks by length so batches of similar length need little padding.
        self.data = sorted(self.data, key=lambda x: len(x[0]))
        self.vocab_sizeN = vocab_sizeN
        self.vocab_sizeT = vocab_sizeT
        self.attn_size = attn_size
        self.eof_N_id = vocab_sizeN - 1   # padding / end-of-file id for non-terminals
        self.eof_T_id = vocab_sizeT - 1   # padding / end-of-file id for terminals
        self.unk_id = vocab_sizeT - 2     # unknown-terminal id
        self.truncate_size = truncate_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def collate_fn(self, samples, device='cpu'):
        # Pad every sequence in the batch to the length of the longest one.
        # Non-terminals and terminals are padded with their EOF ids; parent
        # pointers are padded with 1.
        sent_N = [sample[0] for sample in samples]
        sent_T = [sample[1] for sample in samples]
        sent_P = [sample[2] for sample in samples]
        s_max_length = max(len(x) for x in sent_N)
        sent_N_tensors = []
        sent_T_tensors = []
        sent_P_tensors = []
        for sn, st, sp in zip(sent_N, sent_T, sent_P):
            sn_tensor = torch.ones(s_max_length, dtype=torch.long, device=device) * self.eof_N_id
            st_tensor = torch.ones(s_max_length, dtype=torch.long, device=device) * self.eof_T_id
            sp_tensor = torch.ones(s_max_length, dtype=torch.long, device=device)
            for idx, w in enumerate(sn):
                sn_tensor[idx] = w
                st_tensor[idx] = st[idx]
                sp_tensor[idx] = sp[idx]
            sent_N_tensors.append(sn_tensor.unsqueeze(0))
            sent_T_tensors.append(st_tensor.unsqueeze(0))
            sent_P_tensors.append(sp_tensor.unsqueeze(0))
        # Stack into (batch, max_length) tensors.
        sent_N_tensors = torch.cat(sent_N_tensors, dim=0)
        sent_T_tensors = torch.cat(sent_T_tensors, dim=0)
        sent_P_tensors = torch.cat(sent_P_tensors, dim=0)
        return sent_N_tensors, sent_T_tensors, sent_P_tensors
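
# A minimal usage sketch, assuming the default pickle files above exist and
# preprocess_utils is importable; the batch size is an arbitrary example value.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = MainDataset(is_train=True, truncate_size=150)
    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=False,                  # chunks are already sorted by length
        collate_fn=dataset.collate_fn,  # pads N/T/P sequences per batch
    )
    for n_batch, t_batch, p_batch in loader:
        print(n_batch.shape, t_batch.shape, p_batch.shape)  # (batch, max_len) each
        break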