
Commit 8a1dda3
Adapt to the internal change in gemma.cpp's position calculation
ufownl committed Oct 11, 2024
1 parent 9647ea7 commit 8a1dda3
Showing 2 changed files with 30 additions and 33 deletions.
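
gemma.cpp changed how the start position passed to Generate is interpreted: per this commit, a resumed session now starts one past its last recorded position, while a fresh session still starts at zero. The commit applies that offset everywhere the binding hands a position to the engine, and it also folds the EOS check into the same stop branch as the instruction-tuned end-of-turn (EOT) check. A minimal sketch of the new mapping (the helper name is illustrative, not part of the commit):

// Illustration only: the position mapping this commit applies before
// calling Generate. A fresh session still prefills from 0; a resumed
// session starts one past its stored position.
#include <cstddef>

size_t to_engine_start(size_t session_pos) {
  return session_pos > 0 ? session_pos + 1 : 0;
}

// to_engine_start(0)  == 0   (first turn: prefill from the beginning)
// to_engine_start(41) == 42  (later turn: resume past the stored position)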
src/batch.cpp (14 additions & 16 deletions)
@@ -14,7 +14,7 @@ class batch_data_holder {
   kv_caches_.reserve(sess_ctxs.size());
   for (auto& ctx: sess_ctxs) {
     prompts_.emplace_back(ctx.prompt.data(), ctx.prompt.size());
-    start_pos_.emplace_back(ctx.start_pos);
+    start_pos_.emplace_back(ctx.start_pos > 0 ? ctx.start_pos + 1 : 0);
     kv_caches_.emplace_back(std::move(ctx.sess->kv_cache()));
   }
 }
@@ -178,8 +178,8 @@ int batch(lua_State* L) {
   cfg.batch_stream_token = [&](size_t query_idx, size_t pos, int token, float) {
     auto& ctx = sess_ctxs[query_idx];
     if (ctx.stream_fn == 0) {
-      if (pos - ctx.start_pos >= ctx.prompt.size() && token != gcpp::EOS_ID) {
-        if (inst->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == EOT_ID) {
+      if (pos - ctx.start_pos >= ctx.prompt.size()) {
+        if (token == gcpp::EOS_ID || inst->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == EOT_ID) {
           return false;
         }
         ctx.output.push_back(token);
@@ -189,20 +189,18 @@
     } else {
       auto eot = false;
       lua_pushvalue(L, ctx.stream_fn);
-      if (pos - ctx.start_pos >= ctx.prompt.size() && token != gcpp::EOS_ID) {
-        if (inst->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
-          eot = true;
-          lua_pushnil(L);
-        } else {
-          ctx.output.front() = token;
-          std::string token_text;
-          if (!inst->model().Tokenizer().Decode(ctx.output, &token_text)) {
-            throw std::runtime_error("Tokenizer decoding failed. (batch stream_mode)");
-          }
-          lua_pushlstring(L, token_text.data(), token_text.size());
-        }
-      } else {
-        lua_pushnil(L);
-      }
+      if (pos - ctx.start_pos < ctx.prompt.size()) {
+        lua_pushnil(L);
+      } else if (token == gcpp::EOS_ID || inst->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == EOT_ID) {
+        eot = true;
+        lua_pushnil(L);
+      } else {
+        ctx.output.front() = token;
+        std::string token_text;
+        if (!inst->model().Tokenizer().Decode(ctx.output, &token_text)) {
+          throw std::runtime_error("Tokenizer decoding failed. (batch stream_mode)");
+        }
+        lua_pushlstring(L, token_text.data(), token_text.size());
+      }
       lua_pushinteger(L, pos - ctx.start_pos);
       lua_pushinteger(L, ctx.prompt.size());
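
The second half of the change reworks the stop check in both files: the old code excluded EOS in the outer condition and tested EOT in a nested branch; the new code tests both in one place. Note the operator precedence: since && binds tighter than ||, the condition in the diff parses as "EOS, or (instruction-tuned and EOT)". A restatement as a standalone predicate (a sketch, not part of the commit):

// Sketch of the combined stop condition used by the new branches,
// parenthesized to make the precedence of the diff's expression explicit.
bool is_stop_token(int token, bool instruction_tuned, int eos_id, int eot_id) {
  return token == eos_id || (instruction_tuned && token == eot_id);
}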
src/session.cpp (16 additions & 17 deletions)
@@ -24,15 +24,16 @@ void generate(cgemma::session* sess, const std::vector<int>& prompt, const gcpp:
     };
   }
   cfg.image_tokens = sess->image_tokens();
+  size_t start_pos = sess->pos() > 0 ? sess->pos() + 1 : 0;
   if (cfg.image_tokens) {
     std::vector<int> image_prompt;
     image_prompt.reserve(cfg.image_tokens->BatchSize() + prompt.size());
     image_prompt.resize(cfg.image_tokens->BatchSize(), cgemma::PAD_ID);
     image_prompt.insert(image_prompt.cend(), prompt.cbegin(), prompt.cend());
     cfg.prefill_tbatch_size = image_prompt.size();
-    sess->inst()->model().Generate(cfg, gcpp::PromptTokens(image_prompt.data(), image_prompt.size()), sess->pos(), image_prompt.size(), sess->kv_cache(), sess->timing_info());
+    sess->inst()->model().Generate(cfg, gcpp::PromptTokens(image_prompt.data(), image_prompt.size()), start_pos, image_prompt.size(), sess->kv_cache(), sess->timing_info());
   } else {
-    sess->inst()->model().Generate(cfg, gcpp::PromptTokens(prompt.data(), prompt.size()), sess->pos(), sess->kv_cache(), sess->timing_info());
+    sess->inst()->model().Generate(cfg, gcpp::PromptTokens(prompt.data(), prompt.size()), start_pos, sess->kv_cache(), sess->timing_info());
   }
 }

@@ -49,20 +50,18 @@ int stream_mode(lua_State* L, cgemma::session* sess, const std::vector<int>& pro
   generate(sess, prompt, [&](size_t, size_t pos, int token, float) {
     auto eot = false;
     lua_pushvalue(L, 3);
-    if (pos - start_pos >= prompt_size && token != gcpp::EOS_ID) {
-      if (sess->inst()->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
-        eot = true;
-        lua_pushnil(L);
-      } else {
-        output.front() = token;
-        std::string token_text;
-        if (!sess->inst()->model().Tokenizer().Decode(output, &token_text)) {
-          throw std::runtime_error("Tokenizer decoding failed. (stream_mode)");
-        }
-        lua_pushlstring(L, token_text.data(), token_text.size());
-      }
-    } else {
-      lua_pushnil(L);
-    }
+    if (pos - start_pos < prompt_size) {
+      lua_pushnil(L);
+    } else if (token == gcpp::EOS_ID || sess->inst()->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
+      eot = true;
+      lua_pushnil(L);
+    } else {
+      output.front() = token;
+      std::string token_text;
+      if (!sess->inst()->model().Tokenizer().Decode(output, &token_text)) {
+        throw std::runtime_error("Tokenizer decoding failed. (stream_mode)");
+      }
+      lua_pushlstring(L, token_text.data(), token_text.size());
+    }
     lua_pushinteger(L, pos - start_pos);
     lua_pushinteger(L, prompt_size);
@@ -92,8 +91,8 @@ int normal_mode(lua_State* L, cgemma::session* sess, const std::vector<int>& pro
   std::vector<int> output;
   output.reserve(sess->args().max_generated_tokens);
   generate(sess, prompt, [&](size_t, size_t pos, int token, float) {
-    if (pos - start_pos >= prompt_size && token != gcpp::EOS_ID) {
-      if (sess->inst()->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
+    if (pos - start_pos >= prompt_size) {
+      if (token == gcpp::EOS_ID || sess->inst()->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
         return false;
       }
       output.push_back(token);
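
Taken together, the reworked callbacks follow the same three-way decision in batch, stream, and normal modes: skip positions still inside the prompt, stop on EOS (or EOT for instruction-tuned models), otherwise decode and emit the token. A simplified sketch of that control flow (names are illustrative, not from the commit):

enum class StreamAction { SkipPrompt, Stop, Emit };

// Decision order of the reworked callbacks: prompt positions push nil,
// stop tokens push nil and mark end-of-turn, everything else is decoded
// and pushed as text (or appended to the output in normal mode).
StreamAction classify(size_t pos, size_t start_pos, size_t prompt_size,
                      bool instruction_tuned, int token,
                      int eos_id, int eot_id) {
  if (pos - start_pos < prompt_size) {
    return StreamAction::SkipPrompt;
  }
  if (token == eos_id || (instruction_tuned && token == eot_id)) {
    return StreamAction::Stop;
  }
  return StreamAction::Emit;
}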
