-
Notifications
You must be signed in to change notification settings - Fork 70
/
SentencePiece.cc
146 lines (124 loc) · 4.24 KB
/
SentencePiece.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include "onmt/SentencePiece.h"
#include <sentencepiece_processor.h>
#include <stdexcept>
#include "Utils.h"
namespace onmt
{
static const std::string sp_marker("▁");
// Loads a serialized SentencePiece model from model_path into processor.
// Throws std::invalid_argument on failure; the message now includes the
// underlying sentencepiece status details (previously discarded), so the
// caller can tell a missing file from a corrupt model.
static inline void load_model(sentencepiece::SentencePieceProcessor& processor,
                              const std::string& model_path)
{
  auto status = processor.Load(model_path);
  if (!status.ok())
    throw std::invalid_argument("Unable to open SentencePiece model " + model_path
                                + ": " + status.ToString());
}
// Builds a tokenizer from a serialized SentencePiece model with subword
// regularization disabled. Delegates to the full constructor with the
// sampling parameters zeroed out (nbest_size = 0 disables sampling).
SentencePiece::SentencePiece(const std::string& model_path)
  : SentencePiece(model_path, /*nbest_size=*/0, /*alpha=*/0.f)
{
}
// Builds a tokenizer from a serialized SentencePiece model with subword
// regularization (sampling) parameters.
//   model_path: path to the SentencePiece model file to load.
//   nbest_size: sampling pool size forwarded to SampleEncode during
//               training-time encoding (0 disables sampling; see encode()).
//   alpha: smoothing parameter forwarded to SampleEncode.
// Throws std::invalid_argument if the model cannot be loaded.
SentencePiece::SentencePiece(const std::string& model_path, int nbest_size, float alpha)
: _processor(new sentencepiece::SentencePieceProcessor())
, _nbest_size(nbest_size)
, _alpha(alpha)
{
load_model(*_processor, model_path);
}
// Out-of-line defaulted destructor: defining it here, where
// sentencepiece_processor.h is included, lets the header hold _processor
// without a complete processor type. NOTE(review): assumes _processor is a
// smart pointer to a type only forward-declared in the header — confirm.
SentencePiece::~SentencePiece() = default;
// Switches the tokenization options into SentencePiece compatibility mode
// (spacer annotation, no character substitution) when the caller requested
// no tokenization mode and no annotation scheme of their own.
void SentencePiece::update_tokenization_options(Tokenizer::Options& options) const
{
  // Leave the options untouched if a mode or any annotation was requested.
  if (options.mode != Tokenizer::Mode::None)
    return;
  if (options.joiner_annotate || options.spacer_annotate)
    return;
  options.spacer_annotate = true;
  options.no_substitution = true;
}
// Restricts the encoder to the given vocabulary (same effect as the
// --vocabulary flag of spm_encode).
//   vocabulary: list of allowed pieces.
//   options: current tokenization options, or nullptr; rejected when they
//            are incompatible with spacer-style annotation.
// Throws std::invalid_argument on incompatible options or processor error.
void SentencePiece::set_vocabulary(const std::vector<std::string>& vocabulary,
                                   const Tokenizer::Options* options)
{
  const bool incompatible_options
    = options && (options->joiner_annotate || options->spacer_new);
  if (incompatible_options)
    throw std::invalid_argument("SentencePiece vocabulary restriction requires the tokenization "
                                "to use \"spacer_annotate\" (same as spm_encode)");
  const auto status = _processor->SetVocabulary(vocabulary);
  if (!status.ok())
    throw std::invalid_argument(status.ToString());
}
// Clears any vocabulary restriction previously installed via
// set_vocabulary(), restoring unrestricted segmentation.
void SentencePiece::reset_vocabulary()
{
_processor->ResetVocabulary();
}
// Enables subword regularization for subsequent training-time encodes.
//   nbest_size: sampling pool size (0 disables sampling in encode()).
//   alpha: smoothing parameter forwarded to SampleEncode.
void SentencePiece::enable_regularization(int nbest_size, float alpha)
{
_nbest_size = nbest_size;
_alpha = alpha;
}
// Segments str into SentencePiece pieces.
//   training: when true and regularization is configured (_nbest_size != 0),
//             sample a segmentation instead of taking the best one.
// Returns the pieces in order; may be empty.
std::vector<std::string> SentencePiece::encode(const std::string& str, bool training) const
{
  std::vector<std::string> result;
  const bool use_sampling = training && _nbest_size != 0;
  if (use_sampling)
    _processor->SampleEncode(str, _nbest_size, _alpha, &result);
  else
    _processor->Encode(str, &result);
  return result;
}
// Segments token.surface into sub tokens and annotates them with spacer /
// joiner flags so the original surface can be detokenized.
//   token: the input token whose surface is encoded; its join_left /
//          join_right / preserve flags are propagated to the boundary
//          sub tokens.
//   training: forwarded to encode() to enable sampling when configured.
// Returns the annotated sub tokens, or a copy of the input token when the
// encoder produced nothing usable.
std::vector<Token> SentencePiece::encode_and_annotate(const Token& token, bool training) const
{
  std::vector<std::string> pieces = encode(token.surface, training);
  // SentencePiece sometimes returns no pieces for a non empty input. In this case
  // we simply return the original token.
  if (pieces.empty())
    return std::vector<Token>(1, token);
  std::vector<Token> tokens;
  tokens.reserve(pieces.size());
  bool apply_spacer_on_next = false;
  for (auto& piece : pieces)
  {
    // Prefixed by the spacer.
    if (starts_with(piece, sp_marker))
    {
      if (piece.length() == sp_marker.length()) // Piece is just the spacer.
      {
        // Skip this isolated spacer and mark the next piece with the spacer flag.
        apply_spacer_on_next = true;
        continue;
      }
      else
      {
        Token sub_token(piece.substr(sp_marker.length()));
        sub_token.spacer = true;
        tokens.emplace_back(std::move(sub_token));
      }
    }
    else
    {
      Token sub_token(std::move(piece));
      if (apply_spacer_on_next)
      {
        sub_token.spacer = true;
        sub_token.preserve = true; // The spacer was not attached to this piece so preserve it.
        apply_spacer_on_next = false;
      }
      else if (!tokens.empty())
      {
        sub_token.join_left = true; // No spacer means it should be joined with the previous subtoken.
      }
      tokens.emplace_back(std::move(sub_token));
    }
  }
  // Fix: if every piece was an isolated spacer (e.g. a pure whitespace
  // surface), the loop above produced no sub tokens and the front()/back()
  // accesses below would be undefined behavior. Fall back to the original
  // token, mirroring the empty pieces case.
  if (tokens.empty())
    return std::vector<Token>(1, token);
  // Propagate the boundary flags of the original token onto the first and
  // last sub tokens.
  auto& first = tokens.front();
  auto& last = tokens.back();
  first.join_left = token.join_left;
  last.join_right = token.join_right;
  if (token.join_left && token.preserve)
    first.preserve = true;
  if (token.join_right && token.preserve)
    last.preserve = true;
  propagate_token_properties(token, tokens);
  return tokens;
}
}