forked from FranxYao/nlu-cw2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
108 lines (85 loc) · 4.7 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import collections
import logging
import os
import sys
import re
import pickle
from seq2seq import utils
from seq2seq.data.dictionary import Dictionary
SPACE_NORMALIZER = re.compile("\s+")
def word_tokenize(line):
line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip()
return line.split()
def get_args():
parser = argparse.ArgumentParser('Data pre-processing)')
parser.add_argument('--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('--target-lang', default=None, metavar='TGT', help='target language')
parser.add_argument('--train-prefix', default=None, metavar='FP', help='train file prefix')
parser.add_argument('--tiny-train-prefix', default=None, metavar='FP', help='tiny train file prefix')
parser.add_argument('--valid-prefix', default=None, metavar='FP', help='valid file prefix')
parser.add_argument('--test-prefix', default=None, metavar='FP', help='test file prefix')
parser.add_argument('--dest-dir', default='data-bin', metavar='DIR', help='destination dir')
parser.add_argument('--threshold-src', default=2, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--num-words-src', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--threshold-tgt', default=2, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--num-words-tgt', default=-1, type=int, help='number of target words to retain')
return parser.parse_args()
def main(args):
os.makedirs(args.dest_dir, exist_ok=True)
src_dict = build_dictionary([args.train_prefix + '.' + args.source_lang])
tgt_dict = build_dictionary([args.train_prefix + '.' + args.target_lang])
src_dict.finalize(threshold=args.threshold_src, num_words=args.num_words_src)
src_dict.save(os.path.join(args.dest_dir, 'dict.' + args.source_lang))
logging.info('Built a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict)))
tgt_dict.finalize(threshold=args.threshold_tgt, num_words=args.num_words_tgt)
tgt_dict.save(os.path.join(args.dest_dir, 'dict.' + args.target_lang))
logging.info('Built a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict)))
def make_split_datasets(lang, dictionary):
if args.train_prefix is not None:
make_binary_dataset(args.train_prefix + '.' + lang, os.path.join(args.dest_dir, 'train.' + lang),
dictionary)
if args.tiny_train_prefix is not None:
make_binary_dataset(args.tiny_train_prefix + '.' + lang, os.path.join(args.dest_dir, 'tiny_train.' + lang),
dictionary)
if args.valid_prefix is not None:
make_binary_dataset(args.valid_prefix + '.' + lang, os.path.join(args.dest_dir, 'valid.' + lang),
dictionary)
if args.test_prefix is not None:
make_binary_dataset(args.test_prefix + '.' + lang, os.path.join(args.dest_dir, 'test.' + lang), dictionary)
make_split_datasets(args.source_lang, src_dict)
make_split_datasets(args.target_lang, tgt_dict)
def build_dictionary(filenames, tokenize=word_tokenize):
dictionary = Dictionary()
for filename in filenames:
with open(filename, 'r') as file:
for line in file:
for symbol in word_tokenize(line.strip()):
dictionary.add_word(symbol)
dictionary.add_word(dictionary.eos_word)
return dictionary
def make_binary_dataset(input_file, output_file, dictionary, tokenize=word_tokenize, append_eos=True):
nsent, ntok = 0, 0
unk_counter = collections.Counter()
def unk_consumer(word, idx):
if idx == dictionary.unk_idx and word != dictionary.unk_word:
unk_counter.update([word])
tokens_list = []
with open(input_file, 'r') as inf:
for line in inf:
tokens = dictionary.binarize(line.strip(), word_tokenize, append_eos, consumer=unk_consumer)
nsent, ntok = nsent + 1, ntok + len(tokens)
tokens_list.append(tokens.numpy())
with open(output_file, 'wb') as outf:
pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL)
logging.info('Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token'.format(
input_file, nsent, ntok, 100.0 * sum(unk_counter.values()) / ntok, dictionary.unk_word))
if __name__ == '__main__':
args = get_args()
utils.init_logging(args)
logging.info('COMMAND: %s' % ' '.join(sys.argv))
logging.info('Arguments: {}'.format(vars(args)))
main(args)