-
Notifications
You must be signed in to change notification settings - Fork 0
/
coco_questions.py
76 lines (62 loc) · 2.7 KB
/
coco_questions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import sys
import pandas as pd
import logging
import numpy as np
from mxnet_vqa.data.glove import glove_word2emb_300
from mxnet_vqa.utils.text_utils import word_tokenize, pad_sequences
def get_questions(data_dir_path, max_lines_retrieved=-1, split='val'):
    """Load the ``question`` column of the pickled QA table for one split.

    Args:
        data_dir_path: Directory containing ``data/train_qa`` and
            ``data/val_qa``.
        max_lines_retrieved: Maximum number of questions to return. Any
            negative value (the default is ``-1``) means "no limit".
        split: Either ``'train'`` or ``'val'``.

    Returns:
        A list with one single-element list per table row, each holding that
        row's ``question`` value.

    Raises:
        SystemExit: If ``split`` is neither ``'train'`` nor ``'val'``. The
            exit status is non-zero so the failure is visible to the shell.
    """
    if split == 'train':
        data_path = os.path.join(data_dir_path, 'data/train_qa')
    elif split == 'val':
        data_path = os.path.join(data_dir_path, 'data/val_qa')
    else:
        # Passing a message makes the interpreter write it to stderr and
        # exit with status 1 instead of reporting success (status 0).
        sys.exit('Invalid split!')

    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # QA files that come from a trusted source.
    df = pd.read_pickle(data_path)
    questions = df[['question']].values.tolist()

    # Treat every negative limit as "unlimited"; otherwise a value such as -5
    # would be used as a negative slice bound and silently drop the tail.
    if max_lines_retrieved < 0:
        return questions
    # Slicing already clamps to the list length, so no explicit min() needed.
    return questions[:max_lines_retrieved]
def get_questions_matrix(data_dir_path, max_lines_retrieved=-1, split='val', mode='concat', max_sequence_length=-1):
    """Build the question embedding matrix using the requested strategy.

    Args:
        data_dir_path: Forwarded to the selected builder.
        max_lines_retrieved: Forwarded to the selected builder.
        split: Forwarded to the selected builder.
        mode: ``'add'`` selects ``get_questions_matrix_sum``; every other
            value selects ``get_questions_matrix_concat``.
        max_sequence_length: Forwarded to the concat builder only; ignored
            when ``mode`` is ``'add'``.

    Returns:
        Whatever the selected builder returns.
    """
    # Anything that is not 'add' falls back to the concatenating builder.
    if mode != 'add':
        return get_questions_matrix_concat(
            data_dir_path,
            max_lines_retrieved,
            split,
            max_sequence_length=max_sequence_length,
        )
    return get_questions_matrix_sum(data_dir_path, max_lines_retrieved, split)
def get_questions_matrix_concat(data_dir_path, max_lines_retrieved=-1, split='val', max_sequence_length=-1):
    """Encode each question as a sequence of per-word embeddings, then pad.

    Every question is lower-cased and tokenized; each token is replaced by
    its entry in the table returned by ``glove_word2emb_300`` or, when the
    token is absent from that table, by a zero vector of length 300.

    Args:
        data_dir_path: Passed to both ``get_questions`` and
            ``glove_word2emb_300``.
        max_lines_retrieved: Passed to ``get_questions``.
        split: Passed to ``get_questions``.
        max_sequence_length: Passed to ``pad_sequences``.

    Returns:
        The result of ``pad_sequences`` applied to the list of per-question
        embedding sequences.
    """
    rows = get_questions(data_dir_path, max_lines_retrieved, split)
    embeddings = glove_word2emb_300(data_dir_path)
    logging.debug('glove: %d words loaded', len(embeddings))

    sequences = []
    # Count from 1 so the progress message reports questions processed.
    for count, row in enumerate(rows, start=1):
        tokens = word_tokenize(row[0].lower())
        encoded = [
            embeddings[token] if token in embeddings else np.zeros(shape=300)
            for token in tokens
        ]
        if count % 10000 == 0:
            logging.debug('loaded %d questions', count)
        sequences.append(encoded)

    return pad_sequences(sequences, max_sequence_length=max_sequence_length)
def get_questions_matrix_sum(data_dir_path, max_lines_retrieved=-1, split='val'):
    """Encode each question as the sum of its word embeddings.

    Every question is lower-cased and tokenized. Tokens found in the table
    returned by ``glove_word2emb_300`` contribute their embedding; tokens
    that are absent contribute zeros.

    Args:
        data_dir_path: Passed to both ``get_questions`` and
            ``glove_word2emb_300``.
        max_lines_retrieved: Passed to ``get_questions``.
        split: Passed to ``get_questions``.

    Returns:
        A NumPy array built from one length-300 summed vector per question.
    """
    rows = get_questions(data_dir_path, max_lines_retrieved, split)
    embeddings = glove_word2emb_300(data_dir_path)
    logging.debug('glove: %d words loaded', len(embeddings))

    summed_vectors = []
    # Count from 1 so the progress message reports questions processed.
    for count, row in enumerate(rows, start=1):
        tokens = word_tokenize(row[0].lower())
        # One column per token; unknown tokens keep their all-zero column.
        columns = np.zeros(shape=(300, len(tokens)))
        for col, token in enumerate(tokens):
            if token not in embeddings:
                continue
            columns[:, col] = embeddings[token]
        question_vector = columns.sum(axis=1)
        if count % 10000 == 0:
            logging.debug('loaded %d questions', count)
        summed_vectors.append(question_vector)

    return np.array(summed_vectors)