-
Notifications
You must be signed in to change notification settings - Fork 1
/
gen_json_file.py
57 lines (46 loc) · 1.93 KB
/
gen_json_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-4.0
"""
Generate JSON-formatted input samples, with optional prefixes (such as an isometric translation flag),
for translation with Sockeye models.
"""
import argparse
import json
# Command-line interface: a plain-text input file, a JSON output file, and
# optional source/target prefixes attached to every generated sample.
parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', required=True,
                    help='Source plain file to generate json format for Sockeye decoding.')
parser.add_argument('--output', '-o', required=True,
                    help='Target json file.')
parser.add_argument('--src-prefix', '-sp', required=False, default="",
                    help='Source prefix.')
parser.add_argument('--tgt-prefix', '-tp', required=False, default="",
                    help='Target prefix.')
# Only takes effect when a target prefix is supplied.
parser.add_argument('--keep-tgt-prefix', action='store_true', default=False,
                    help='If true, keeps the tgt-prefix.')
def samples_to_dict(input_f, src_p, tgt_p, keep_tgtp):
    """Read a plain-text file and build one JSON-ready dict per line.

    Each line is stripped of surrounding whitespace and stored under the
    ``"text"`` key. Prefix information, when given, is attached to every
    sample.

    Args:
        input_f: Path of the plain-text input file, one sample per line.
        src_p: Source prefix; stored under ``"source_prefix"`` when non-empty.
        tgt_p: Target prefix; stored under ``"target_prefix"`` when non-empty.
        keep_tgtp: Whether the target prefix should be kept. Accepts a bool,
            or the legacy strings ``"true"`` / ``"false"``. Only stored (under
            ``"keep_target_prefix"``) when a target prefix is given.

    Returns:
        A list of dicts, in input order, ready to be serialized as JSON.
    """
    # Legacy callers pass the strings "true"/"false". Stored verbatim they
    # would be serialized as JSON strings, and "false" is truthy to any
    # consumer that tests the value, so convert them to real booleans.
    if isinstance(keep_tgtp, str):
        keep_tgtp = keep_tgtp.strip().lower() == "true"

    # The prefix fields are identical for every sample; build them once.
    extras = {}
    if src_p:
        extras["source_prefix"] = src_p
    if tgt_p:
        # The key must not carry a trailing colon, otherwise the prefix ends
        # up under an unrecognized key.
        extras["target_prefix"] = tgt_p
        extras["keep_target_prefix"] = bool(keep_tgtp)

    # Read explicitly as UTF-8 so the result does not depend on the locale.
    with open(input_f, 'r', encoding='utf-8') as input_fo:
        return [{"text": line.strip(), **extras} for line in input_fo]
def make_json(samples, output):
    """Write samples to a file as one JSON object per line.

    Args:
        samples: Iterable of JSON-serializable objects (one per sample).
        output: Path of the file to create or overwrite.

    Note:
        The result is line-delimited JSON (one object per line), not a
        single JSON array.
    """
    # Non-ASCII characters are written unescaped (ensure_ascii=False), so the
    # file must be opened as UTF-8 rather than the locale default, which may
    # be unable to encode them.
    with open(output, 'w', encoding='utf-8') as out_fo:
        for sample in samples:
            out_fo.write(json.dumps(sample, ensure_ascii=False) + "\n")
if __name__ == '__main__':
    args = parser.parse_args()
    # Pass the flag as a real bool: json.dump serializes True/False as the
    # JSON booleans true/false, whereas the strings "true"/"false" would be
    # emitted as quoted JSON strings (and "false" is truthy to a consumer).
    keep_tgtp = bool(args.keep_tgt_prefix)
    samples = samples_to_dict(args.input, args.src_prefix, args.tgt_prefix, keep_tgtp)
    make_json(samples, args.output)