-
Notifications
You must be signed in to change notification settings - Fork 9.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : save and restore kv cache for single seq id #6341
Conversation
Thanks very much for working on this! ❤️ Would it be worth adding a header with some metadata - e.g. magic number (to check it's the right file type), file format version (in case the format is ever tweaked in the future), some kind of model checks (so you can check if the model you're about to load it into is compatible). This doesn't have to be used at the moment, but I've often regretted designing persistent file formats without reserving some space for things like that. Looking at the functions available from the point of view of a developer trying to integrate this with an external llama.cpp wrapper (LLamaSharp) my main concern is that there's no way to load a sequence without risk of crashing the process ( GGML_ASSERT(n_layer == n_layer_ref);
GGML_ASSERT(n_embd_v_gqa == n_embd_v_gqa_ref); Can this be modified to fail in a more graceful way? e.g. return an error code. |
Feel free to make the necessary adjustments to the code. :) |
I'll give it a go. I'm not very familiar with C++, but hopefully just adapting it to return some error codes should be easy enough! |
I'm already returning 0 as failure value for the case where there's no available space in the kv cache, so maybe just do that everywhere on the format reference values checks. Might be useful to output some LOG_ERROR messages. |
@kaetemi here are my proposed changes: martindevans@62370b0 |
size_t is unsigned, though, not sure if returning negative values works out here |
Oops, I'm used to Rust with I'll change all of the errors to simply return 0 as you suggested. |
@kaetemi here's a new proposed set of changes: martindevans@b182f8f |
llama.h
Outdated
LLAMA_API size_t llama_get_seq_size( | ||
struct llama_context * ctx, | ||
llama_seq_id seq_id); | ||
|
||
LLAMA_API size_t llama_copy_seq_data( | ||
struct llama_context * ctx, | ||
uint8_t * dst, | ||
llama_seq_id seq_id); | ||
|
||
// Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence. | ||
// Returns: | ||
// - Positive: Ok | ||
// - Zero: Failed to load | ||
LLAMA_API size_t llama_set_seq_data( | ||
struct llama_context * ctx, | ||
const uint8_t * src, | ||
llama_seq_id dest_seq_id); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@compilade, @kaetemi and others
It would be useful to update the names of the state management API - either in this PR or in another. For example:
LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
LLAMA_API size_t llama_state_get_data(
struct llama_context * ctx,
uint8_t * dst);
LLAMA_API size_t llama_state_set_data(
struct llama_context * ctx,
const uint8_t * src);
LLAMA_API bool llama_state_load_file(
struct llama_context * ctx,
const char * path_session,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
LLAMA_API bool llama_state_save_file(
struct llama_context * ctx,
const char * path_session,
const llama_token * tokens,
size_t n_token_count);
LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx,
llama_seq_id seq_id);
LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx,
uint8_t * dst,
llama_seq_id seq_id);
LLAMA_API size_t llama_state_seq_set_data(
struct llama_context * ctx,
const uint8_t * src,
llama_seq_id seq_id);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that looks a lot more orderly. I'll change that. (done)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be useful to deprecate the old functions using the DEPRECATED
macro in llama.h
and update the README section "Recent API changes" to help 3rd party devs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, thanks. I'll do that and update the docs (done).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for briging this to the server, please add related tests scenario in the server test framework.
examples/server/server.cpp
Outdated
@@ -3519,6 +3767,12 @@ int main(int argc, char ** argv) { | |||
svr->Post("/v1/embeddings", handle_embeddings); | |||
svr->Post("/tokenize", handle_tokenize); | |||
svr->Post("/detokenize", handle_detokenize); | |||
if (!sparams.slot_save_path.empty()) { | |||
// only enable slot endpoints if slot_save_path is set | |||
svr->Post("/slot/save", handle_slot_save); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be restful compliant, the good path must be /slots/{slot-id}?action=save
, verb must not be present in the path and resources are in the plural form.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, I'll look into adding server test cases (done). Adjusting the endpoints (done) as well if that's the preferred style.
Not a backdoor. Promise. 😎 |
…session api and add version tags
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks! Please wait for @ggerganov approval
llama.cpp
Outdated
if (n_layer != n_layer_ref) { | ||
return 0; | ||
} | ||
if (n_embd_v_gqa != n_embd_v_gqa_ref) { | ||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to print an error explaining the reason why the state cannot be set.
llama.cpp
Outdated
size_t k_size_row_ref; | ||
memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); | ||
inp += sizeof(k_size_row_ref); | ||
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); | ||
if (k_size_row != k_size_row_ref) { | ||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); | ||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of saving the row size in the state? If the goal is to check that the KV data types ares the same, why not export the tensor type instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems safer for parsing to keep the actual value used to calculate the data length in a binary format. (Also keeps it easier for any third-party tool to blindly parse through it.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type size is not unique. For example, q4_0 and iq4_nl have the same size.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, we can make it stricter and add the type as well. My concern here is mainly on avoiding any chance of buffer overflows by being explicit on the sizes. (I'm also considering external tools that might want to splice together or trim sequences, which could just treat the data as a black box and only need to know the data length.)
llama_file file(filepath, "wb"); | ||
|
||
file.write_u32(LLAMA_STATE_SEQ_MAGIC); | ||
file.write_u32(LLAMA_STATE_SEQ_VERSION); | ||
|
||
// save the prompt | ||
file.write_u32((uint32_t)n_token_count); | ||
file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llama_file
throws on failure, and we should avoid passing these exceptions to the user. Instead, the exceptions should be caught and an error code should be returned to the user (the same issue exists in llama_save_session_file
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(done)
Thanks, let's merge after resolving the conflicts |
Resolved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
* llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans <martindevans@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
See #5843
This adds
llama_get_seq_size
,llama_copy_seq_data
, andllama_set_seq_data
functions to save and restore the kv cache of a single sequence id.On the server this adds
/slot/save
and/slot/restore
endpoints, which will save the kv cache for the given slot along with the token cache to a file. Also added a/slot/erase
to just wipe the kv cache for a slot.Works so far,
but still needs test cases, and perhaps a path parameter for the server to enable the functionality restricted within a specified folder(done). Might be missing some special cases that I'm not aware of.