data_preparation.py
"""Prepare per-class datasets and loaders for training, validation, test,
and out-of-vocabulary (OOV) evaluation."""
import os
import logging
import pickle as pkl

from utils import load_embeddings, load_examples, collate_fn, collate_x
from per_class_dataset import *

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

def truncate_examples(examples, n):
    """Truncate each left/right context to n // 2 tokens and relabel each
    example with its target word."""
    m = n // 2
    truncated_examples = []
    for (CL, word, CR), _label in examples:
        truncated_examples.append(((CL[-m:], word, CR[:m]), word))
    return truncated_examples
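
# For instance (hypothetical tokens), with n=4 each context keeps n // 2 = 2
# tokens and the label becomes the word itself:
#   truncate_examples([((['a', 'b', 'c'], 'w', ['d', 'e', 'f']), 'lbl')], 4)
#   -> [((['b', 'c'], 'w', ['d', 'e']), 'w')]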

def prepare_data(dataset,
                 embeddings,
                 vectorizer,
                 n=15,
                 ratio=.8,
                 use_gpu=False,
                 k=1,
                 data_augmentation=False,
                 over_population_threshold=100,
                 relative_over_population=True,
                 debug_mode=False,
                 verbose=True,
                 ):
    # Train-validation part
    path = './data/' + dataset.dataset_name + '/examples/'
    if data_augmentation:
        examples = load_examples(path + 'augmented_examples_topn5_cos_sim0.6.pkl')
    else:
        examples = load_examples(path + 'examples.pkl')
    if debug_mode:
        examples = list(examples)[:128]
    examples = truncate_examples(examples, n)
    transform = vectorizer.vectorize_unknown_example

    def target_transform(y):
        return embeddings[y]

    train_valid_dataset = PerClassDataset(
        examples,
        transform=transform,
        target_transform=target_transform,
    )
    train_dataset, valid_dataset = train_valid_dataset.split(
        ratio=ratio, shuffle=True, reuse_label_mappings=False)

    filter_labels_cond = None
    if over_population_threshold is not None:
        if relative_over_population:
            # Reinterpret the threshold relative to the most frequent label.
            over_population_threshold = int(
                train_valid_dataset.stats()['most common labels number of examples']
                / over_population_threshold)

        def filter_labels_cond(label, N):
            return N <= over_population_threshold
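    # For example (illustrative numbers): if the most common label has 5000
    # examples and over_population_threshold=100, the effective threshold
    # becomes int(5000 / 100) = 50, and `filter_labels_cond` tests N <= 50.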
    train_loader = PerClassLoader(dataset=train_dataset,
                                  collate_fn=collate_fn,
                                  batch_size=64,
                                  k=k,
                                  use_gpu=use_gpu,
                                  filter_labels_cond=filter_labels_cond)
    valid_loader = PerClassLoader(dataset=valid_dataset,
                                  collate_fn=collate_fn,
                                  batch_size=64,
                                  k=k,
                                  use_gpu=use_gpu,
                                  filter_labels_cond=filter_labels_cond)

    # Test part
    test_examples = load_examples(path + 'valid_test_examples.pkl')
    test_examples = truncate_examples(test_examples, n)
    test_dataset = PerClassDataset(dataset=test_examples,
                                   transform=transform)
    # k=-1 presumably loads every example per label (training uses k per
    # label per epoch).
    test_loader = PerClassLoader(dataset=test_dataset,
                                 collate_fn=collate_x,
                                 k=-1,
                                 shuffle=False,
                                 batch_size=64,
                                 use_gpu=use_gpu)

    # OOV part
    oov_examples = load_examples(path + 'oov_examples.pkl')
    oov_examples = truncate_examples(oov_examples, n)
    oov_dataset = PerClassDataset(dataset=oov_examples,
                                  transform=transform)
    oov_loader = PerClassLoader(dataset=oov_dataset,
                                collate_fn=collate_x,
                                k=-1,
                                shuffle=False,
                                batch_size=64,
                                use_gpu=use_gpu)

    if verbose:
        logging.info('Number of unique examples: {}'.format(len(examples)))

        logging.info('\nGlobal statistics:')
        stats = train_valid_dataset.stats()
        for stat, value in stats.items():
            logging.info(stat + ': ' + str(value))

        logging.info('\nStatistics on the training dataset:')
        stats = train_dataset.stats(over_population_threshold)
        for stat, value in stats.items():
            logging.info(stat + ': ' + str(value))

        logging.info('\nStatistics on the validation dataset:')
        stats = valid_dataset.stats(over_population_threshold)
        for stat, value in stats.items():
            logging.info(stat + ': ' + str(value))

        logging.info('\nStatistics on the test dataset:')
        stats = test_dataset.stats()
        for stat, value in stats.items():
            logging.info(stat + ': ' + str(value))

        logging.info('\nFor training, loading ' + str(k) +
                     ' examples per label per epoch.')

    return train_loader, valid_loader, test_loader, oov_loader
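
# Minimal usage sketch (a sketch, not part of this module): `ConllDataset` and
# `WordVectorizer` are hypothetical stand-ins for the project's own dataset and
# vectorizer classes, and the embeddings path is illustrative.
#
#     embeddings = load_embeddings('./data/embeddings.pkl')
#     dataset = ConllDataset()           # must expose a `dataset_name` attribute
#     vectorizer = WordVectorizer(embeddings)
#     train_loader, valid_loader, test_loader, oov_loader = prepare_data(
#         dataset, embeddings, vectorizer, n=15, k=4, use_gpu=False)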