PR: Add KV-cache creation capability to mlx_lm.generate for after a text completion #1001

Closed
wants to merge 5 commits

Conversation

mark-lord

Whilst mlx_lm.cache_prompt lets you encode a prompt and save the key-value pairs as a kv-cache.safetensors file in advance, there's currently no means of saving the kv-cache after a text completion by an LLM.

Adding this will bring MLX_lm more in line with Llama.cpp in terms of reducing latency in multi-turn scenarios - for example, using an MLX-served model as a chatbot and having a drawn-out discussion about a given topic. Saving the KV-cache after each turn by the LLM means that even as the conversation history grows, there won't be any latency introduced by having to re-encode the entire chat log again - only the most recent user prompt.

Not sure I went about it the best way, but it seems to work from my testing! There's one superfluous edit to the step generator (line 357) which can probably be left out, but otherwise I think I kept this as streamlined as I could.

Description of changes
Enable saving the kv-cache as a safetensors file after a text completion: once the generate step has finished producing all the tokens, the key-value cache is converted into a dict and saved with mx.save_safetensors to a user-specified file location, similar to cache_prompt (see the sketch below).
Add a --save-kv-cache 'path/to/cache' argument to mlx_lm.generate and pass it through to the modified utils.py.
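
For illustration, here's a minimal sketch of what the save step could look like, assuming a list of per-layer cache objects that expose pre-allocated keys/values buffers and an offset marking how much of them is filled (attribute and key names are illustrative, not the exact code in this PR):

```python
import mlx.core as mx

def save_kv_cache(cache, path):
    # Flatten the per-layer key/value arrays into a single dict so they
    # can be written with mx.save_safetensors, mirroring cache_prompt.
    cache_dict = {}
    for i, c in enumerate(cache):
        # Save only the filled portion of each layer's buffers.
        cache_dict[f"{i}_keys"] = c.keys[..., : c.offset, :]
        cache_dict[f"{i}_values"] = c.values[..., : c.offset, :]
    mx.save_safetensors(path, cache_dict)
```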
@mark-lord
Author

Oop, forgot to add an explanation of how to use this. Copied + pasted from my explanation in the MLX Discord:

On first-time initialisation of a chat you'll need to create the cache for the model to use, which can be done with the normal cache_prompt.py:

mlx_lm.cache_prompt --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --save-kv-cache "rollingcache.safetensors" --prompt "From now on, talk like a pirate"

And then afterwards you'd do something like:

mlx_lm.generate --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --save-kv-cache "rollingcache.safetensors" --kv-cache-file "rollingcache.safetensors" --prompt "Tell me a joke"

Which creates the generation:
Arrr, settle yerself down with a pint o' grog and listen close, me hearty. Here be a joke fer ye:
Why did the pirate quit his job?
(pause for dramatic effect)
Because he was sick o' all the arrrr-guments! (wink) Savvy?

Seems to work perfectly fine on a rolling basis. For example, you can then follow up with:

mlx_lm.generate --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --save-kv-cache "rollingcache.safetensors" --kv-cache-file "rollingcache.safetensors" --prompt "Explain that joke to me"

To which you'll get the reply,
Alright then, matey, settle yerself down with a pint o' grog and listen close. Ye want to know the joke, eh? Alright then, I'll give it to ye. So, I be tellin' ye this one: Why'd the pirate quit his job? (pauses for dramatic effect)
He be sick o' all the arrrr-guments, savvy? Arrrr, get it? It's a pirate pun, matey!

By setting --save-kv-cache and --kv-cache-file to the same path, it'll load the kv-cache file, take your new prompt and generate from it, then overwrite the original kv-cache file, thereby making a rolling kv-cache. Even as the conversation history grows, time to first token stays low as there's no need to re-encode the whole history.
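
For completeness, the load side under the same illustrative layout might look roughly like this, assuming each per-layer cache object has an update_and_fetch(keys, values) method as in MLX LM's KVCache (again an assumption, not the exact PR code):

```python
import mlx.core as mx

def load_kv_cache(path, cache):
    # Read back the flattened dict written by save_kv_cache and restore
    # each layer's keys/values so generation can continue from them.
    data = mx.load(path)
    for i, c in enumerate(cache):
        c.update_and_fetch(data[f"{i}_keys"], data[f"{i}_values"])
    return cache
```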

@mark-lord
Author

Oop, didn't realise there was already a PR on this: #989. Haven't had time to take a look yet, but from what I gather it might be tackling largely the same thing.

@awni
Member

awni commented Oct 6, 2024

Thanks for the PR! I changed how the caching works in MLX LM a bit to make this exact use case much easier.

You can see PR #1015 and the example there. It makes avoiding recomputation of KV caches and serializing them to disk much cleaner / easier.
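
Roughly, the new usage looks something like the following sketch (assuming the make_prompt_cache / save_prompt_cache helpers and the prompt_cache argument to generate from that PR; see #1015 for the exact API):

```python
from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

# Build a reusable prompt cache, generate against it, then persist it to disk.
cache = make_prompt_cache(model)
response = generate(model, tokenizer, prompt="Tell me a joke", prompt_cache=cache)
save_prompt_cache("rollingcache.safetensors", cache)
```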

It should mostly subsume the case you fixed here, so I will close this.

@awni closed this on Oct 6, 2024