forked from facebookresearch/XLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
translate.py
150 lines (119 loc) · 5.73 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Translate sentences from the input stream.
# The model will be faster if sentences are sorted by length.
# Input sentences must have the same tokenization and BPE codes as the ones used in the model.
#
# Usage:
# cat source_sentences.bpe | \
# python translate.py --exp_name translate \
# --src_lang en --tgt_lang fr \
# --model_path trained_model.pth --output_path output
#
import os
import io
import sys
import argparse
import torch
from src.utils import AttrDict
from src.utils import bool_flag, initialize_exp
from src.data.dictionary import Dictionary
from src.model.transformer import TransformerModel
from src.fp16 import network_to_half
def get_parser():
    """
    Build the command-line parser of the translation script.

    Returns:
        argparse.ArgumentParser: parser exposing the experiment options,
        the model / output paths and the source / target languages.
    """
    # One (flag, type, default, help) entry per option, registered in this
    # exact order so the generated --help output keeps the same layout.
    option_specs = [
        # main parameters
        ("--dump_path", str, "./dumped/", "Experiment dump path"),
        ("--exp_name", str, "", "Experiment name"),
        ("--exp_id", str, "", "Experiment ID"),
        ("--fp16", bool_flag, False, "Run model with float16"),
        ("--batch_size", int, 32, "Number of sentences per batch"),
        # model / output paths
        ("--model_path", str, "", "Model path"),
        ("--output_path", str, "", "Output path"),
        # disabled options, kept for reference:
        # ("--max_vocab", int, -1, "Maximum vocabulary size (-1 to disable)"),
        # ("--min_count", int, 0, "Minimum vocabulary count"),
        # source language / target language
        ("--src_lang", str, "", "Source language"),
        ("--tgt_lang", str, "", "Target language"),
    ]

    arg_parser = argparse.ArgumentParser(description="Translate sentences")
    for flag, value_type, default_value, help_text in option_specs:
        arg_parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return arg_parser
def main(params):
    """
    Translate the sentences read from stdin and write them to `params.output_path`.

    The checkpoint at `params.model_path` is reloaded, an encoder and a decoder
    are rebuilt from it on the GPU, and stdin is translated batch by batch from
    `params.src_lang` to `params.tgt_lang`. Each translation is written on its
    own line of the output file and echoed to stderr next to its source.

    Args:
        params: parsed command-line namespace (see `get_parser`). It is used
            as given (the command line is not parsed again here). Reads
            `model_path`, `output_path`, `src_lang`, `tgt_lang`, `batch_size`
            and `fp16`. As a side effect the dictionary indices of the
            checkpoint (`n_words`, `bos_index`, ...) and the language ids
            `src_id` / `tgt_id` are stored on it.

    Raises:
        KeyError: if `src_lang` or `tgt_lang` is not a key of the checkpoint's
            `lang2id` mapping.
        AssertionError: if an input line contains no token, or if a decoded
            sentence does not start with the end-of-sentence index.
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # coming from a trusted source.
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin:
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))

    # the context manager closes the output file even if translation fails
    with io.open(params.output_path, 'w', encoding='utf-8') as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch: one column per sentence, each sentence wrapped
            # between two eos_index delimiters and padded with pad_index
            word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
                        for s in src_sent[i:i + params.batch_size]]
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            max_src_len = lengths.max().item()
            batch = torch.LongTensor(max_src_len, lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                n_words = len(s)
                if n_words > 0:  # if sentence not empty
                    batch[1:n_words + 1, j].copy_(s)
                batch[n_words + 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            lengths_gpu = lengths.cuda()
            encoded = encoder('fwd', x=batch.cuda(), lengths=lengths_gpu, langs=langs.cuda(), causal=False)
            encoded = encoded.transpose(0, 1)
            decoded, dec_lengths = decoder.generate(encoded, lengths_gpu, params.tgt_id,
                                                    max_len=int(1.5 * max_src_len + 10))

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: drop the leading eos_index and everything
                # from the second eos_index on (if there is one)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join(dico[word_id] for word_id in sent.tolist())
                sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
                f.write(target + "\n")
if __name__ == '__main__':

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()

    # check parameters -- explicit errors rather than asserts, so that the
    # checks still run under `python -O` and the user is told what is wrong
    if not os.path.isfile(params.model_path):
        parser.error("--model_path must point to an existing file (got %r)" % params.model_path)
    if params.src_lang == '' or params.tgt_lang == '':
        parser.error("--src_lang and --tgt_lang must both be provided")
    if params.src_lang == params.tgt_lang:
        parser.error("--src_lang and --tgt_lang must be different")
    if not params.output_path:
        parser.error("--output_path must be provided")
    if os.path.isfile(params.output_path):
        # refuse to overwrite a previous output
        parser.error("--output_path already exists (got %r)" % params.output_path)

    # translate (inference only, so no gradients are tracked)
    with torch.no_grad():
        main(params)