server : reuse context chunks #9866

Merged: 1 commit, Oct 13, 2024
common/arg.cpp (7 additions, 0 deletions)
@@ -1788,6 +1788,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         params.n_threads_http = value;
     }
 ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
+add_opt(common_arg(
+    {"--cache-reuse"}, "N",
+    string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse),
+    [](common_params & params, int value) {
+        params.n_cache_reuse = value;
+    }
+).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_REUSE"));
 add_opt(common_arg(
     {"--metrics"},
     string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"),
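The new option follows the existing server-flag pattern: `--cache-reuse N` (or the `LLAMA_ARG_CACHE_REUSE` environment variable) stores its value in `params.n_cache_reuse`, and the default of 0 leaves chunk reuse disabled. As a hypothetical invocation, `llama-server -m model.gguf --cache-reuse 256` would allow the server to reuse any matching cached chunk of at least 256 tokens.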
common/common.h (2 additions, 1 deletion)
@@ -277,7 +277,8 @@ struct common_params {
     int32_t port = 8080; // server listens on this network port
     int32_t timeout_read = 600; // http read timeout in seconds
     int32_t timeout_write = timeout_read; // http write timeout in seconds
-    int n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
 
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
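Besides the new `n_cache_reuse` field (default 0, i.e. disabled), the hunk widens `n_threads_http` from `int` to `int32_t`, an incidental cleanup that keeps the type consistent with the neighboring fields.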
examples/server/README.md (1 addition, 0 deletions)
@@ -147,6 +147,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
 | `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
 | `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
+| `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)<br/>(env: LLAMA_ARG_CACHE_REUSE) |
 | `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
 | `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
 | `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
examples/server/server.cpp (66 additions, 3 deletions)
@@ -800,7 +800,7 @@ struct server_context {
     int slot_prompt_len = slot_prompt.size();
 
     // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-    int lcp_len = common_part(slot_prompt, prompt);
+    int lcp_len = longest_common_prefix(slot_prompt, prompt);
 
     // fraction of the common substring length compared to the current slot's prompt length
     similarity = static_cast<float>(lcp_len) / slot_prompt_len;
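The rename makes the helper's role explicit. As a concrete example: if a slot's cached prompt is 1000 tokens long and an incoming prompt shares an 800-token prefix with it, `lcp_len` is 800 and `similarity` is 0.8, and the server favors the slot whose cached prompt matches the request best.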
@@ -2012,7 +2012,7 @@ struct server_context {
     }
     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-    // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+    // if input prompt is too big, truncate it
     if (slot.n_prompt_tokens >= slot.n_ctx) {
         const int n_left = slot.n_ctx - slot.params.n_keep;

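The dropped parenthetical referred to the group-attention self-extend mechanism, which was removed from the server in an earlier change; truncation now applies unconditionally.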
@@ -2042,12 +2042,74 @@ struct server_context {
 
     if (slot.params.cache_prompt) {
         // reuse any previously computed tokens that are common with the new prompt
-        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+        slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
         // push the prompt into the sampling context (do not apply grammar)
         for (int i = 0; i < slot.n_past; ++i) {
             common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
         }
+
+        // reuse chunks from the cached prompt by shifting their KV cache in the new position
+        if (params.n_cache_reuse > 0) {
+            size_t head_c = slot.n_past; // cache
+            size_t head_p = slot.n_past; // current prompt
+
+            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+            while (head_c < slot.cache_tokens.size() &&
+                   head_p < prompt_tokens.size()) {
+                if (llama_token_is_control(model, slot.cache_tokens[head_c])) {
+                    break;
+                }
+
+                if (llama_token_is_control(model, prompt_tokens[head_p])) {
+                    break;
+                }
+
+                size_t n_match = 0;
+
+                while (head_c + n_match < slot.cache_tokens.size() &&
+                       head_p + n_match < prompt_tokens.size() &&
+                       slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                    if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match])) {
+                        break;
+                    }
+
+                    if (llama_token_is_control(model, prompt_tokens[head_p + n_match])) {
+                        break;
+                    }
+
+                    n_match++;
+                }
+
+                if (n_match >= (size_t) params.n_cache_reuse) {
+                    SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                    //for (size_t i = head_p; i < head_p + n_match; i++) {
+                    //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                    //}
+
+                    const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                    llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                    llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                    for (size_t i = 0; i < n_match; i++) {
+                        slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                        common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                        slot.n_past++;
+                    }
+
+                    head_c += n_match;
+                    head_p += n_match;
+                } else {
+                    head_c += 1;
+                }
+            }
+
+            SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+        }
     }
 }

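This loop is the core of the change. After the usual longest-common-prefix reuse, two cursors scan forward: `head_c` over the cached tokens and `head_p` over the new prompt. Whenever a run of at least `params.n_cache_reuse` identical tokens is found, the run's KV-cache entries are relocated instead of recomputed: `llama_kv_cache_seq_rm` evicts the stale entries in `[head_p, head_c)`, and `llama_kv_cache_seq_add` shifts the entries from `head_c` onward by `kv_shift = head_p - head_c`, placing the matched chunk at its position in the new prompt. Control tokens act as barriers that stop the scan. The standalone sketch below reproduces only the matching bookkeeping on plain integer IDs; the `is_control()` predicate (negative IDs) and the `chunk_shift` record are hypothetical stand-ins for `llama_token_is_control` and for the pair of KV-cache calls.

#include <cstdio>
#include <vector>

// Minimal sketch of the chunk-matching logic above, under these assumptions:
// tokens are plain ints, negative IDs play the role of control tokens, and a
// chunk_shift record stands in for the llama_kv_cache_seq_rm/seq_add pair.
struct chunk_shift { size_t src, dst, len; };

static bool is_control(int tok) { return tok < 0; } // hypothetical predicate

static std::vector<chunk_shift> find_reusable_chunks(
        const std::vector<int> & cache, const std::vector<int> & prompt,
        size_t n_past, size_t min_chunk) {
    std::vector<chunk_shift> shifts;

    size_t head_c = n_past; // cursor into the cached tokens
    size_t head_p = n_past; // cursor into the new prompt

    while (head_c < cache.size() && head_p < prompt.size()) {
        if (is_control(cache[head_c]) || is_control(prompt[head_p])) {
            break; // control tokens are barriers; never shift across them
        }

        // measure the run of identical tokens starting at the two cursors
        size_t n_match = 0;
        while (head_c + n_match < cache.size() &&
               head_p + n_match < prompt.size() &&
               cache[head_c + n_match] == prompt[head_p + n_match] &&
               !is_control(cache[head_c + n_match])) {
            n_match++;
        }

        if (n_match >= min_chunk) {
            // the server would now shift KV entries [src, src+len) -> [dst, dst+len)
            shifts.push_back({head_c, head_p, n_match});
            head_c += n_match;
            head_p += n_match;
        } else {
            head_c += 1; // chunk too small; slide the cache cursor and retry
        }
    }

    return shifts;
}

int main() {
    // the prompts share the prefix {1, 2}; the new prompt then drops the
    // cached tokens {3, 4, 5}, so the chunk {7, 8, 9, 10} reappears 3 slots earlier
    const std::vector<int> cache  = {1, 2, 3, 4, 5, 7, 8, 9, 10};
    const std::vector<int> prompt = {1, 2, 7, 8, 9, 10, 11};

    for (const chunk_shift & s : find_reusable_chunks(cache, prompt, /*n_past =*/ 2, /*min_chunk =*/ 3)) {
        std::printf("reuse %zu tokens: cache [%zu, %zu) -> prompt [%zu, %zu), kv_shift = %lld\n",
                    s.len, s.src, s.src + s.len, s.dst, s.dst + s.len,
                    (long long) s.dst - (long long) s.src);
    }

    return 0;
}

Run as-is, the sketch reports a single chunk, `reuse 4 tokens: cache [5, 9) -> prompt [2, 6), kv_shift = -3`: the kind of situation this PR targets, where tokens dropped from the middle of a long context (for example, old turns trimmed from a chat history) no longer force a full re-decode of everything that follows them.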
@@ -3257,6 +3319,7 @@ int main(int argc, char ** argv) {
 
 ctx_server.queue_tasks.on_new_task(std::bind(
     &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
 ctx_server.queue_tasks.on_update_slots(std::bind(
     &server_context::update_slots, &ctx_server));

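The only change in this final hunk is an added blank line between the two queue-callback registrations; it is purely cosmetic.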
examples/server/utils.hpp (2 additions, 2 deletions)
@@ -195,14 +195,14 @@ static std::string gen_chatcmplid() {
 // other common utils
 //
 
-static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
 
     return i;
 }
 
-static size_t common_part(const std::string & a, const std::string & b) {
+static size_t longest_common_prefix(const std::string & a, const std::string & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

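The rename from `common_part` to `longest_common_prefix` is purely descriptive; both overloads behave as before. A minimal usage sketch for the token-vector overload, assuming `llama_token` is `int32_t` as in llama.h (the token IDs are made up):

#include <cassert>
#include <cstdint>
#include <vector>

typedef int32_t llama_token; // matches the typedef in llama.h

// copied from utils.hpp above
static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

    return i;
}

int main() {
    const std::vector<llama_token> cached = {15043, 3186, 29892, 920}; // hypothetical token IDs
    const std::vector<llama_token> fresh  = {15043, 3186, 29892, 526};

    assert(longest_common_prefix(cached, fresh)  == 3); // inputs diverge at index 3
    assert(longest_common_prefix(cached, cached) == 4); // identical inputs: full length

    return 0;
}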