feat(tokenizers): Add decoding support to HFTokenizer
pytorch#1251
Branch: TokenizersCpp-1251

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
gabe-l-hart committed Nov 15, 2024
1 parent 7084831 commit b8c7941
Showing 3 changed files with 48 additions and 7 deletions.
18 changes: 18 additions & 0 deletions tokenizer/hf_tokenizers.cpp
@@ -114,6 +114,13 @@ void HFTokenizer::load(const std::string& path) {
     exit(EXIT_FAILURE);
   }
 
+  // Set up the decoder (optional)
+  try {
+    _decoder = TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create();
+  } catch (const json::out_of_range& e) {
+    // No decoder specified
+  }
+
   // TODO: Do we need to parse the merges?
 
   // If a tokenizer config file is found, parse it to look up the eos/bos tokens
@@ -241,3 +248,14 @@ void HFTokenizer::_encode(
     ret.insert(ret.end(), tokens.begin(), tokens.end());
   }
 }
+
+void HFTokenizer::_decode(
+    re2::StringPiece input,
+    std::string& ret
+) const {
+  if (_decoder) {
+    ret += _decoder->decode(input);
+  } else {
+    ret += input;
+  }
+}
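
The "decoder" entry parsed above is the optional decoder section of a standard HuggingFace tokenizer.json (for byte-level BPE models it typically looks like {"type": "ByteLevel", ...}). A minimal sketch of the same optional-field handling, reusing the TokenDecoderConfig factory named in the diff and assuming TokenDecoder::Ptr is a default-null smart pointer (the `if (_decoder)` check above suggests it); the contains() call is just an exception-free alternative to catching json::out_of_range, and the free-function wrapper is illustrative, not part of the patch:

#include <nlohmann/json.hpp>

#include "token_decoder.h"

using json = nlohmann::json;

// Sketch: build the decoder only when tokenizer.json carries a "decoder"
// section; otherwise leave the pointer empty so decoding falls back to
// appending the raw token bytes.
TokenDecoder::Ptr make_optional_decoder(const json& parsed_json) {
  TokenDecoder::Ptr decoder;  // stays null when no decoder is configured
  if (parsed_json.contains("decoder")) {
    decoder = TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create();
  }
  return decoder;
}
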
9 changes: 8 additions & 1 deletion tokenizer/tiktoken.cpp
@@ -329,6 +329,12 @@ void Tiktoken::_encode(
   }
 }
 
+void Tiktoken::_decode(
+    re2::StringPiece input,
+    std::string& ret) const {
+  ret += input;
+}
+
 // -------------------------private method end-------------------------------
 // -------------------------public method start-------------------------------
 
@@ -386,8 +392,9 @@ std::string BPETokenizerBase::decode(uint64_t prev, uint64_t cur) const {
       exit(EXIT_FAILURE);
     }
   }
-  ret += token_bytes;
+  _decode(token_bytes, ret);
+
   return ret;
 }
 
 // -------------------------public method end-------------------------------
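
With this change, the shared BPETokenizerBase::decode still looks up the token bytes as before, but hands them to the virtual _decode hook instead of concatenating them directly, so HFTokenizer can apply its configured TokenDecoder while Tiktoken keeps the old pass-through behavior. A rough usage sketch of the per-token decode(prev, cur) entry point this feeds into; the detokenize helper and the use of bos_tok() as the initial `prev` are illustrative assumptions, not part of the patch:

#include <cstdint>
#include <string>
#include <vector>

#include "tokenizer.h"

// Sketch: convert a generated id sequence back to text one token at a time.
std::string detokenize(const Tokenizer& tok, const std::vector<uint64_t>& ids) {
  std::string text;
  uint64_t prev = tok.bos_tok();  // assumed starting id; any previous token works
  for (uint64_t cur : ids) {
    text += tok.decode(prev, cur);
    prev = cur;
  }
  return text;
}
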
28 changes: 22 additions & 6 deletions tokenizer/tokenizer.h
@@ -25,6 +25,7 @@
 
 #include "pre_tokenizer.h"
 #include "sentencepiece_processor.h"
+#include "token_decoder.h"
 
 class Tokenizer {
  public:
@@ -125,10 +126,15 @@ class BPETokenizerBase : public Tokenizer {
   Decoder special_token_decoder_;
 
  private:
+
   virtual void _encode(
-    re2::StringPiece& input,
-    std::vector<uint64_t>& ret,
-    uint64_t& last_piece_token_len) const = 0;
+      re2::StringPiece& input,
+      std::vector<uint64_t>& ret,
+      uint64_t& last_piece_token_len) const = 0;
+
+  virtual void _decode(
+      re2::StringPiece input,
+      std::string& ret) const = 0;
 };
 
 class Tiktoken : public BPETokenizerBase {
@@ -160,9 +166,13 @@ class Tiktoken : public BPETokenizerBase {
   }
 
   void _encode(
-    re2::StringPiece& input,
-    std::vector<uint64_t>& ret,
-    uint64_t& last_piece_token_len) const override;
+      re2::StringPiece& input,
+      std::vector<uint64_t>& ret,
+      uint64_t& last_piece_token_len) const override;
+
+  void _decode(
+      re2::StringPiece input,
+      std::string& ret) const override;
 
   // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
   const std::string _pattern =
@@ -194,10 +204,16 @@ class HFTokenizer : public BPETokenizerBase {
   void load(const std::string& tokenizer_path) override;
 
  private:
+
   void _encode(
       re2::StringPiece& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const override;
+
+  void _decode(
+      re2::StringPiece input,
+      std::string& ret) const override;
 
   PreTokenizer::Ptr _pretokenizer;
+  TokenDecoder::Ptr _decoder;
 };
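
After this patch, the contract for a BPETokenizerBase subclass is a pair of private hooks: _encode maps a pre-split piece of text to token ids, and _decode maps looked-up token bytes to output text. A minimal sketch of what a hypothetical new subclass would override (the class name and bodies are illustrative; load() and construction are omitted since only the hook pair is shown):

// Sketch: the two hooks every BPETokenizerBase subclass now has to provide.
class PassthroughTokenizer : public BPETokenizerBase {
 private:
  void _encode(
      re2::StringPiece& input,
      std::vector<uint64_t>& ret,
      uint64_t& last_piece_token_len) const override {
    // Split `input` into pieces, append their ids to `ret`, and record the
    // token length of the final piece in `last_piece_token_len`.
  }

  void _decode(re2::StringPiece input, std::string& ret) const override {
    // No decoder transformation: append the raw token bytes, like Tiktoken.
    ret += input;
  }
};
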
