More cache improvements #1015
Conversation
…ases for e.g. caching during a chat
I also added a chat command to MLX LM which is a good use case for the prompt cache re-use. The example is kind of fun to play with:
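As an illustrative stand-in (not the exact example from the PR), prompt-cache reuse from Python looks roughly like this, assuming the mlx_lm API from around this change (load, make_prompt_cache, and a prompt_cache argument to generate); the model name is just a placeholder:

```python
# Illustrative sketch of prompt-cache reuse across chat turns. Assumes the
# mlx_lm Python API from around this change (load, make_prompt_cache, and a
# prompt_cache argument to generate); the model name is just a placeholder.
from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt_cache = make_prompt_cache(model)

# Every turn reuses the same cache, so earlier turns are never recomputed.
for user_turn in ["Hi! My name is Ada.", "What is my name?"]:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_turn}],
        add_generation_prompt=True,
        tokenize=False,
    )
    reply = generate(
        model, tokenizer, prompt=prompt, prompt_cache=prompt_cache, verbose=False
    )
    print(reply)
```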
Then you can just chat with the model, and it preserves the history and doesn't do any prompt recomputation.
This is a fantastic refactoring! Especially love the tests.
I am wondering, what is the point of the extra state in the KV cache? Is anybody using it now? Is there any reason it is set to the empty string instead of None?
I may have missed a discussion, but I couldn't find its use in the code either.
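For context, the design the question refers to appears to be a state property that returns a small metadata string alongside the arrays; a minimal sketch of that reading (an assumption based on the save/load hunk further down, not quoted code):

```python
# Minimal sketch of the "extra state" being asked about (assumed from the
# save/load diff below, not quoted from the PR): state is an (arrays, info)
# pair, and the plain KV cache has no metadata, so its info slot is "".
class KVCache:
    def __init__(self):
        self.keys = None
        self.values = None

    @property
    def state(self):
        return (self.keys, self.values), ""
```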
```diff
     ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
```
❤️
```diff
@@ -198,20 +197,12 @@ def __call__(
         self,
         x: mx.array,
         mask: mx.array = None,
-        cache: mx.array = None,
-    ) -> Tuple[mx.array, mx.array]:
+        cache=None,
```
I don't personally care, but in most cases this is written as `cache: Optional[Any] = None`.
Good catch, will fix.
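For concreteness, the suggested annotation would read something like this (a sketch; the class and surrounding signature are abridged stand-ins, not the actual model code):

```python
# Sketch of the reviewer's suggestion; the class and signature are abridged
# stand-ins, not the actual model code.
from typing import Any, Optional

import mlx.core as mx


class TransformerBlock:
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        ...
```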
""" | ||
cache_data, cache_info = zip(*(c.state for c in cache)) | ||
cache_data = dict(tree_flatten(cache_data)) | ||
cache_classes = [type(c).__name__ for c in cache] |
❤️
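For readers following along, here is a rough sketch of how state flattened this way might round-trip through a safetensors file (illustrative only, not the exact code merged in this PR; it assumes state is a writable (arrays, info) pair with info a plain, comma-free string, as in the hunk above):

```python
# Illustrative sketch only, not the code merged in this PR. Assumes each cache
# exposes a writable state of the form (arrays, info), with info a plain
# string containing no commas, as in the hunk above.
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


def save_prompt_cache(file_name, cache):
    # Separate the array part of each cache's state from its string metadata.
    cache_data, cache_info = zip(*(c.state for c in cache))
    cache_data = dict(tree_flatten(cache_data))  # arrays keyed by tree path
    metadata = {
        "classes": ",".join(type(c).__name__ for c in cache),
        "info": ",".join(cache_info),
    }
    mx.save_safetensors(file_name, cache_data, metadata=metadata)


def load_prompt_cache(file_name, cache):
    # Restore the saved state into an existing list of cache objects.
    arrays, metadata = mx.load(file_name, return_metadata=True)
    cache_data = tree_unflatten(list(arrays.items()))
    for c, data, info in zip(cache, cache_data, metadata["info"].split(",")):
        c.state = (data, info)
```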
It's my least favorite thing in this diff, but I didn't think of a cleaner solution yet (if you have ideas I'm all 👂). It is used only for the […]. The reason I made it a string and not None […].
Huh, not sure how I managed to miss it in the […]. I think this should be changed because the […]. Nothing changes much. Line 47 would do something like:

```python
cache_data = [c.state for c in cache]
cache_info = [c.serialization_state for c in cache]  # do we also want type(self).__name__ here?
```

and line 75 would change to:

```python
for c, state, serialization_state in zip(cache, arrays, info):
    c.state = state
    c.serialization_state = serialization_state
```

The rest remains the same... Wdyt?
Yea, I thought about a separate property... and/or overriding […].
I added a small base class that implements the empty meta state and makes the load/save code a tad bit cleaner. Should I push it on top, or are we avoiding base classes for some reason? Also, I just played with the chat command; it is an absolute joy to use :-)
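A rough sketch of the kind of base class being described (names and details are assumptions, not the merged code; here state carries only arrays while meta_state carries the string metadata separately):

```python
# Rough sketch of the proposed base class (names are assumptions, not the
# merged code). Here state carries only arrays, and meta_state carries the
# string metadata, defaulting to "" so most caches need no extra code.
class _BaseCache:
    @property
    def meta_state(self):
        return ""

    @meta_state.setter
    def meta_state(self, v):
        # Subclasses with real metadata override both the getter and setter.
        if v:
            raise ValueError("This cache has no meta_state to restore.")


class KVCache(_BaseCache):
    def __init__(self):
        self.keys = None
        self.values = None

    @property
    def state(self):
        return self.keys, self.values
```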
No reason at this point, please send the diff!
Ok, I tested prompt caching with a few different models / cache types and it seems to work well. I'm going to merge this. As a follow-up we should consider: […]
Awesome work, thanks for fixing it!
So excited this got merged 😄
Sorry for the large diff. There's a lot of boilerplate / moving stuff around, which accounts for most of it.
The main bits are:
- KVCache […]

Closes #1000