Skip to content

Commit

Permalink
Filter disabled tokens after prompts tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Aug 1, 2024
1 parent ffa1153 commit 3b6b2b9
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
namespace {

constexpr const char name[] = "cgemma.session";
constexpr const int gemma_unk_id = 3;
constexpr const int gemma_eot_id = 107;

std::vector<int> text2prompt(cgemma::session* sess, const char* text, size_t len) {
Expand Down Expand Up @@ -41,6 +42,11 @@ std::vector<int> text2prompt(cgemma::session* sess, const char* text, size_t len
if (!sess->inst()->model().Tokenizer().Encode(s, &prompt)) {
throw std::runtime_error("Tokenizer encoding failed. (text2prompt)");
}
if (!sess->inst()->disabled_tokens().empty()) {
std::replace_if(prompt.begin(), prompt.end(), [&](int token) {
return sess->inst()->disabled_tokens().find(token) != sess->inst()->disabled_tokens().end();
}, gemma_unk_id);
}
if (sess->pos() == 0) {
prompt.emplace(prompt.cbegin(), gcpp::BOS_ID);
}
Expand All @@ -53,9 +59,11 @@ void generate(cgemma::session* sess, const std::vector<int>& prompt, const gcpp:
cfg.verbosity = 0;
cfg.gen = &sess->rnd();
cfg.stream_token = stream_token;
cfg.accept_token = [&](int token, float) {
return sess->inst()->disabled_tokens().find(token) == sess->inst()->disabled_tokens().end();
};
if (!sess->inst()->disabled_tokens().empty()) {
cfg.accept_token = [&](int token, float) {
return sess->inst()->disabled_tokens().find(token) == sess->inst()->disabled_tokens().end();
};
}
sess->inst()->model().Generate(cfg, prompt, sess->pos(), sess->kv_cache(), sess->timing_info());
}

Expand Down

0 comments on commit 3b6b2b9

Please sign in to comment.