-
Notifications
You must be signed in to change notification settings - Fork 1
/
sentence_piece.py
executable file
·65 lines (55 loc) · 2.64 KB
/
sentence_piece.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-4.0
"""
Script for training, encoding, and decoding sub-word segmentation model
using SentencePiece.
"""
import sentencepiece as sp
import argparse
from contextlib import ExitStack
def train(sp_model, inputs, vocab_size, model_type):
    """Train a SentencePiece segmentation model.

    :param sp_model: Output model prefix (writes <prefix>.model / <prefix>.vocab).
    :param inputs: Comma-separated list of training text files.
    :param vocab_size: Number of sub-word segmentation rules to learn.
    :param model_type: SentencePiece model type (e.g. "bpe", "unigram").
    """
    train_args = (f'--input={inputs} --model_prefix={sp_model} '
                  f'--vocab_size={vocab_size} --model_type={model_type}')
    sp.SentencePieceTrainer.Train(train_args)
def encode_decode(sp_model, inputs, outputs, mode=None):
    """Encode raw text into sub-word pieces, or decode pieces back to text.

    Processes the parallel ``inputs``/``outputs`` file lists line by line.

    :param sp_model: Path to a trained SentencePiece ``.model`` file.
    :param inputs: List of input file paths.
    :param outputs: List of output file paths, parallel to ``inputs``.
    :param mode: Either "encode" or "decode".
    :raises ValueError: If ``mode`` is neither "encode" nor "decode".
    """
    sp_spp = sp.SentencePieceProcessor()
    sp_spp.Load(sp_model)
    if mode == "encode":
        def preprocess(sample):
            return " ".join(sp_spp.EncodeAsPieces(sample.strip()))
    elif mode == "decode":
        def preprocess(sample):
            return "".join(sp_spp.DecodePieces(sample.strip().split()))
    else:
        # Originally, any other mode (including the default None) left
        # `preprocess` unbound and crashed later with a NameError; fail fast
        # with a clear error instead.
        raise ValueError(f"SP mode {mode} unknown; expected 'encode' or 'decode'.")
    with ExitStack() as stack:
        infiles = [stack.enter_context(open(infile, "r", encoding="utf-8")) for infile in inputs]
        outfiles = [stack.enter_context(open(outfile, "w", encoding="utf-8")) for outfile in outputs]
        # Read all input files in lockstep; write one processed line per output file.
        for samples in zip(*infiles):
            for sample, outfile in zip(map(preprocess, samples), outfiles):
                print(sample, file=outfile)
def main():
    """Command-line entry point: train a SentencePiece model, or encode/decode files."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--run", choices=['train', 'encode', 'decode'], required=True,
                        help="SP run type.")
    parser.add_argument("--sp_model", required=True,
                        help="SP model.")
    # nargs="+" produces a list; default to an empty list, not a string.
    parser.add_argument("--inputs", nargs="+", type=str, default=[],
                        help="input files for SP run type.")
    parser.add_argument("--vocab_size", required=False, default=3200,
                        help="number of SP segmentation rules.")
    parser.add_argument("--model_type", required=False, default="bpe",
                        help="SP model type (bpe, unigram).")
    parser.add_argument("--outputs", nargs="+", required=False, default=[],
                        help="files for SP run outputs.")
    args = parser.parse_args()
    if args.run == "train":
        # SentencePiece expects a single comma-separated --input value.
        inputs = ",".join(args.inputs)
        print(f"SentencePiece: {args.run} ...")
        train(args.sp_model, inputs, args.vocab_size, args.model_type)
    elif args.run in ("encode", "decode"):
        # Raise a real exception instead of assert: asserts are stripped under "python -O".
        if len(args.inputs) != len(args.outputs):
            raise ValueError("different input and output file size.")
        print(f"SentencePiece: {args.run} ...")
        encode_decode(args.sp_model, args.inputs, args.outputs, args.run)
    else:
        # Unreachable while --run uses choices=, kept as a defensive guard.
        raise ValueError(f"SP run mode {args.run} unknown.")


if __name__ == "__main__":
    main()