Skip to content

Commit

Permalink
Adapt to the runtime Top-K configuration of gemma.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Nov 14, 2024
1 parent ae2fd2f commit 4f8cbb3
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 8 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ Available options and default values:
prefill_tbatch = 64, -- Prefill: max tokens per batch.
decode_qbatch = 16, -- Decode: max queries per batch.
temperature = 1.0, -- Temperature for top-K.
top_k = 1, -- Number of top-K tokens to sample from.
}
```

Expand Down Expand Up @@ -347,7 +348,7 @@ The stream function is the same as in [metatable(cgemma.session).call](#metatabl
> 1. Each element in a batch must start with a session, followed by a string and an optional stream function; when a stream function is provided, the corresponding session runs in stream mode instead of normal mode;
> 2. All sessions in a batch must be created by the same Gemma instance;
> 3. Sessions in a batch must not be duplicated;
> 4. Inference arguments of batch call: `max_generated_tokens`, `prefill_tbatch`, and `decode_qbatch` will be the minimum value of all sessions, and `temperature` will be the average value of all sessions.
> 4. The inference arguments of a batch call are combined across all of its sessions: `max_generated_tokens`, `prefill_tbatch`, and `decode_qbatch` take the minimum value, `temperature` takes the average value, and `top_k` takes the maximum value.
#### cgemma.batch_result.stats

Expand Down
3 changes: 2 additions & 1 deletion demo/cgemma_demo.conf
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ http {

function gemma_loop()
local session, err = gemma_inst():session({
temperature = 0.4
temperature = 0.4,
top_k = 50
})
if not session then
ngx.log(ngx.ERR, "cgemma error: ", err)
Expand Down
2 changes: 1 addition & 1 deletion examples/ai_function.lua
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ local function implement(...)
error("Opoos! "..err)
end
table.insert(funcs, function(...)
local session, err = gemma:session()
local session, err = gemma:session({top_k = 50})
if not session then
error("Opoos! "..err)
end
Expand Down
2 changes: 1 addition & 1 deletion examples/batch_interface.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end
-- Create 3 chat sessions
local sessions = {}
for i = 1, 3 do
local session, err = gemma:session()
local session, err = gemma:session({top_k = 50})
if not session then
error("Opoos! "..err)
end
Expand Down
2 changes: 1 addition & 1 deletion examples/normal_mode.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ if args.image then
end

-- Create a chat session
local session, err = image and gemma:session(image) or gemma:session()
local session, err = image and gemma:session(image, {top_k = 50}) or gemma:session({top_k = 50})
if not session then
error("Opoos! "..err)
end
Expand Down
2 changes: 1 addition & 1 deletion examples/stream_mode.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ if args.image then
end

-- Create a chat session
local session, err = image and gemma:session(image) or gemma:session()
local session, err = image and gemma:session(image, {top_k = 50}) or gemma:session({top_k = 50})
if not session then
error("Opoos! "..err)
end
Expand Down
2 changes: 2 additions & 0 deletions src/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,13 @@ gcpp::RuntimeConfig parse_config(const std::vector<cgemma::session_context>& ses
cfg.prefill_tbatch_size = 4096;
cfg.decode_qbatch_size = 4096;
cfg.temperature = 0.0f;
cfg.top_k = 1;
for (const auto& ctx: sess_ctxs) {
cfg.max_generated_tokens = std::min(cfg.max_generated_tokens, ctx.sess->args().max_generated_tokens);
cfg.prefill_tbatch_size = std::min(cfg.prefill_tbatch_size, ctx.sess->args().prefill_tbatch_size);
cfg.decode_qbatch_size = std::min(cfg.decode_qbatch_size, ctx.sess->args().decode_qbatch_size);
cfg.temperature += ctx.sess->args().temperature;
cfg.top_k = std::max(cfg.top_k, ctx.sess->args().top_k);
}
cfg.temperature /= sess_ctxs.size();
return cfg;
Expand Down
1 change: 0 additions & 1 deletion src/cgemma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ int info(lua_State* L) {
auto now = std::time(nullptr);
std::cout << "Date & Time : " << std::put_time(std::localtime(&now), "%F %T") << std::endl;
std::cout << "Max Sequence Length : " << gcpp::kSeqLen << std::endl;
std::cout << "Top-K : " << gcpp::kTopK << std::endl;
char cpu[100];
if (hwy::platform::GetCpuString(cpu)) {
std::cout << "CPU : " << cpu << std::endl;
Expand Down
3 changes: 2 additions & 1 deletion src/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,8 @@ int session::create(lua_State* L) {
"--max_generated_tokens",
"--prefill_tbatch",
"--decode_qbatch",
"--temperature"
"--temperature",
"--top_k"
};
constexpr const int n = sizeof(available_options) / sizeof(available_options[0]);
int argc = 1;
Expand Down

0 comments on commit 4f8cbb3

Please sign in to comment.