# Create a vocabulary wrapper
import pickle
from collections import Counter
from pycocotools.coco import COCO
import json
import argparse
import os
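
# Annotation files for each supported dataset, resolved relative to
# <data_path>/<data_name>/ (see build_vocab below).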
annotations = {
    'mrw': ['mrw-v1.0.json'],
    'tgif': ['tgif-v1.0.tsv'],
    'coco': ['annotations/captions_train2014.json',
             'annotations/captions_val2014.json'],
}


class Vocabulary(object):
    """Simple vocabulary wrapper."""

    def __init__(self):
        self.idx = 0
        self.word2idx = {}
        self.idx2word = {}

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)
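

# Illustrative usage (not executed here):
#   vocab = Vocabulary()
#   vocab.add_word('<unk>')
#   vocab.add_word('cat')
#   vocab('cat')   # -> 1
#   vocab('dog')   # -> 0: unknown words fall back to the '<unk>' index
#   len(vocab)     # -> 2

# Caption loaders: each returns a flat list of caption strings. The TGIF
# TSV is assumed to hold the caption in the second tab-separated column.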
def from_tgif_tsv(path):
    with open(path, 'r') as f:
        captions = [line.strip().split('\t')[1] for line in f]
    return captions
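

# MRW captions come from Reddit post titles; strip subreddit/user markup
# ("/r/", "/u/") and normalise separators before tokenising.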
def from_mrw_json(path):
    dataset = json.load(open(path, 'r'))
    captions = []
    for datum in dataset:
        cap = datum['sentence']
        cap = cap.replace('/r/', '')
        cap = cap.replace('r/', '')
        cap = cap.replace('/u/', '')
        cap = cap.replace('u/', '')
        cap = cap.replace('..', '')
        cap = cap.replace('/', ' ')
        cap = cap.replace('-', ' ')
        captions.append(cap)
    return captions


def from_coco_json(path):
    coco = COCO(path)
    captions = []
    for idx in coco.anns.keys():
        captions.append(str(coco.anns[idx]['caption']))
    return captions


def from_txt(txt):
    captions = []
    with open(txt, 'r') as f:  # text mode: the tokenizer expects str, not bytes
        for line in f:
            captions.append(line.strip())
    return captions
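

# build_vocab falls back to plain-text captions (one per line) for any other
# data_name. Note: nltk.tokenize.word_tokenize needs the NLTK 'punkt' data;
# run nltk.download('punkt') once if it is missing.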
def build_vocab(data_path, data_name, jsons, threshold):
    """Build a simple vocabulary wrapper."""
    import nltk
    counter = Counter()
    for path in jsons[data_name]:
        full_path = os.path.join(data_path, data_name, path)
        if data_name == 'tgif':
            captions = from_tgif_tsv(full_path)
        elif data_name == 'mrw':
            captions = from_mrw_json(full_path)
        elif data_name == 'coco':
            captions = from_coco_json(full_path)
        else:
            captions = from_txt(full_path)
        for caption in captions:
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)

    # Discard words that occur fewer than `threshold` times.
    words = [word for word, cnt in counter.items() if cnt >= threshold]
    print('Vocabulary size: {}'.format(len(words)))

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add words to the vocabulary.
    for word in words:
        vocab.add_word(word)
    return vocab
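

# Note: unpickling the saved vocab later requires the Vocabulary class above
# to be importable (pickle stores only a reference to the class).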
def main(data_path, data_name, threshold):
    vocab = build_vocab(data_path, data_name,
                        jsons=annotations, threshold=threshold)
    if not os.path.isdir('./vocab'):
        os.makedirs('./vocab')
    vocab_path = './vocab/%s_vocab.pkl' % data_name
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
    print("Saved vocabulary file to ", vocab_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--threshold', default=0, type=int,
                        help='minimum word count to keep a word')
    parser.add_argument('--data_path', default=os.path.join(os.getcwd(), 'data'))
    parser.add_argument('--data_name', default='coco', help='coco|tgif|mrw')
    opt = parser.parse_args()
    main(opt.data_path, opt.data_name, opt.threshold)
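
# Example invocation (the file layout under --data_path is an assumption):
#   python vocab.py --data_name coco --data_path ./data --threshold 4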