Skip to content

Commit

Permalink
Adapt to the naming changes in gemma.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Dec 14, 2024
1 parent 745e811 commit 30fded1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ int batch(lua_State* L) {
auto& ctx = sess_ctxs[query_idx];
if (ctx.stream_fn == 0) {
if (pos - ctx.start_pos >= ctx.prompt.size()) {
if (token == gcpp::EOS_ID || inst->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == EOT_ID) {
if (token == gcpp::EOS_ID || inst->model().Info().wrapping == gcpp::PromptWrapping::GEMMA_IT && token == EOT_ID) {
return false;
}
ctx.output.push_back(token);
Expand All @@ -191,7 +191,7 @@ int batch(lua_State* L) {
lua_pushvalue(L, ctx.stream_fn);
if (pos - ctx.start_pos < ctx.prompt.size()) {
lua_pushnil(L);
} else if (token == gcpp::EOS_ID || inst->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == EOT_ID) {
} else if (token == gcpp::EOS_ID || inst->model().Info().wrapping == gcpp::PromptWrapping::GEMMA_IT && token == EOT_ID) {
eot = true;
lua_pushnil(L);
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int stream_mode(lua_State* L, cgemma::session* sess, const std::vector<int>& pro
lua_pushvalue(L, 3);
if (pos - start_pos < prompt_size) {
lua_pushnil(L);
} else if (token == gcpp::EOS_ID || sess->inst()->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
} else if (token == gcpp::EOS_ID || sess->inst()->model().Info().wrapping == gcpp::PromptWrapping::GEMMA_IT && token == cgemma::EOT_ID) {
eot = true;
lua_pushnil(L);
} else {
Expand Down Expand Up @@ -92,7 +92,7 @@ int normal_mode(lua_State* L, cgemma::session* sess, const std::vector<int>& pro
output.reserve(sess->args().max_generated_tokens);
generate(sess, prompt, [&](size_t, size_t pos, int token, float) {
if (pos - start_pos >= prompt_size) {
if (token == gcpp::EOS_ID || sess->inst()->model().Info().training == gcpp::ModelTraining::GEMMA_IT && token == cgemma::EOT_ID) {
if (token == gcpp::EOS_ID || sess->inst()->model().Info().wrapping == gcpp::PromptWrapping::GEMMA_IT && token == cgemma::EOT_ID) {
return false;
}
output.push_back(token);
Expand Down Expand Up @@ -329,7 +329,7 @@ std::vector<int> session::tokenize(const char* text, size_t len) const {
constexpr const char model_sot[] = "<start_of_turn>model\n";
constexpr const char eot[] = "<end_of_turn>\n";
std::string s;
if (inst_->model().Info().training == gcpp::ModelTraining::GEMMA_IT) {
if (inst_->model().Info().wrapping == gcpp::PromptWrapping::GEMMA_IT) {
s.reserve(sizeof(eot) - 1
+ sizeof(user_sot) - 1
+ len
Expand Down Expand Up @@ -359,7 +359,7 @@ std::vector<int> session::tokenize(const char* text, size_t len) const {
if (pos_ == 0) {
prompt.emplace(prompt.cbegin(), gcpp::BOS_ID);
}
if (inst_->model().Info().training == gcpp::ModelTraining::PALIGEMMA) {
if (inst_->model().Info().wrapping == gcpp::PromptWrapping::PALIGEMMA) {
std::vector<int> sep;
if (!inst_->model().Tokenizer().Encode("\n", &sep)) {
throw std::runtime_error("Tokenizer encoding failed. (session::tokenize)");
Expand Down

0 comments on commit 30fded1

Please sign in to comment.