make_vocab.py (forked from ucfnlp/summarization-sing-pair-mix)
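"""Builds a word-frequency vocabulary file for a summarization dataset.

Reads the TensorFlow example shards under <data_root>/<dataset_name> for the
requested split(s), counts every token in the article sentences and reference
summaries, and writes the VOCAB_SIZE most frequent tokens (one "word count"
pair per line) to logs/vocab_<dataset_name>.

Example invocation (flag values are illustrative):
    python make_vocab.py --dataset_name=cnn_dm --dataset_split=all
"""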
import collections
import util
from absl import app, flags
from tqdm import tqdm
import os
import glob
import data
import convert_data
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'cnn_dm',
                    'Which dataset to use. Makes a log dir based on name. '
                    'Must be one of {cnn_dm, xsum, duc_2004}')
flags.DEFINE_string('data_root', 'data/tf_data', 'Path to root directory for all datasets (already converted to TensorFlow examples).')
flags.DEFINE_string('dataset_split', 'all', 'Which dataset split to use. Must be one of {train, val, test, all}')
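# (feature name, type) pairs to pull out of each tf.Example via util.unpack_tf_example.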
names_to_types = [('raw_article_sents', 'string_list'), ('article', 'string'), ('abstract', 'string_list'), ('doc_indices', 'string')]
VOCAB_SIZE = 200000
def main(unused_argv):
    if len(unused_argv) != 1:  # extra positional arguments mean the flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    if FLAGS.dataset_split == 'all':
        dataset_splits = ['test', 'val', 'train']
    else:
        dataset_splits = [FLAGS.dataset_split]

    vocab_counter = collections.Counter()
    for dataset_split in dataset_splits:
        source_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
        source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
        total = len(source_files) * 1000  # rough example count for the progress bar (assumes ~1000 examples per shard)
        example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False,
                                                   should_check_valid=False)
        for example_idx, example in enumerate(tqdm(example_generator, total=total)):
            raw_article_sents, article, abstracts, doc_indices = util.unpack_tf_example(
                example, names_to_types)
            # Tokenize article sentences and reference-summary sentences, dropping <s>/</s> markers.
            article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
            groundtruth_summ_sent_tokens = [[token for token in abstract.strip().split() if token not in ['<s>', '</s>']]
                                            for abstract in abstracts]
            all_tokens = util.flatten_list_of_lists(article_sent_tokens) + util.flatten_list_of_lists(groundtruth_summ_sent_tokens)
            vocab_counter.update(all_tokens)

    print("Writing vocab file...")
    with open(os.path.join('logs', "vocab_" + FLAGS.dataset_name), 'w') as writer:
        for word, count in vocab_counter.most_common(VOCAB_SIZE):
            writer.write(word + ' ' + str(count) + '\n')
    print("Finished writing vocab file")


if __name__ == '__main__':
    app.run(main)