
More cache improvements #1015

Merged 12 commits into main from more_cache_improvements on Oct 8, 2024
Conversation

@awni (Member) commented Oct 5, 2024

Sorry for the large diff. There's a lot of boilerplate / moving stuff around, which accounts for most of it.

The main bits are:

  • Fix RotatingKVCache for alternating chat, response use case
  • Enable prompt caching for all types (not just KVCache)
  • Unify APIs and cache types in a single file for ease of use / consistency.
  • Chat mode allows prompt caching for efficiency (see the chat example and the sketch below).
  • Add a bunch of tests.

Closes #1000
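
For context, here is a minimal sketch of how the unified prompt-cache API can be used to pre-compute a prompt cache, save it, and reuse it later. The names make_prompt_cache, save_prompt_cache, load_prompt_cache, and the prompt_cache argument to generate follow this PR's cache module, but the exact signatures here are assumptions, not the merged implementation:

    # Sketch only: function names and signatures are assumed, not definitive.
    from mlx_lm import load, generate
    from mlx_lm.models.cache import (
        make_prompt_cache,
        save_prompt_cache,
        load_prompt_cache,
    )

    model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

    # Build a cache matching the model's layers / cache types and fill it once.
    prompt_cache = make_prompt_cache(model)
    generate(model, tokenizer, prompt="You are a terse assistant.",
             prompt_cache=prompt_cache, max_tokens=1)

    # Works for any cache type, not just KVCache, since each cache serializes
    # its own state.
    save_prompt_cache("prefix_cache.safetensors", prompt_cache)

    # Later: reload and keep generating without recomputing the shared prefix.
    prompt_cache = load_prompt_cache("prefix_cache.safetensors")
    answer = generate(model, tokenizer, prompt="What's the tallest mountain?",
                      prompt_cache=prompt_cache, max_tokens=128)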

@awni (Member, Author) commented Oct 7, 2024

I also added a chat command to MLX LM, which is a good use case for prompt cache reuse. The example is kind of fun to play with:

mlx_lm.chat

Then you can just chat with the model; it preserves the history and doesn't do any prompt recomputation.

[INFO] Starting chat session with mlx-community/Llama-3.2-3B-Instruct-4bit. To exit, enter 'q'.
>> Hi, my name is Awni!
Hi Awni! It's nice to meet you. Is there something I can help you with or would you like to chat?
>> What's the tallest mountain in the world?
The tallest mountain in the world is Mount Everest, which is located in the Himalayas on the border between Nepal and Tibet, China. It stands at an elevation of 8,848 meters (29,029 feet) above sea level.
>> Do you remember my name?
Yes, your name is Awni.
>> Nice talking with you!
It's great to chat with you too, Awni! Is there anything else you'd like to talk about or ask about?
>> 
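
Under the hood, a chat loop like this can keep a single prompt cache alive across turns so each message only processes the new tokens. A simplified, hypothetical sketch (not the actual mlx_lm.chat code; make_prompt_cache and the prompt_cache argument are assumed as above):

    # Simplified multi-turn loop (sketch); the real chat command also applies
    # the model's chat template when building each prompt.
    from mlx_lm import load, generate
    from mlx_lm.models.cache import make_prompt_cache

    model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
    prompt_cache = make_prompt_cache(model)  # persists across turns

    while True:
        user = input(">> ")
        if user == "q":
            break
        # The cache already holds the earlier turns, so only the new tokens
        # are processed and the model still "remembers" the conversation.
        print(generate(model, tokenizer, prompt=user, prompt_cache=prompt_cache))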

@angeloskath (Member) left a comment

This is a fantastic refactoring! Especially love the tests.

I am wondering what is the point of the extra state in the KV cache? Is anybody using it now? Is there any reason it is set to the empty string instead of None?

I may have missed a discussion, but I couldn't find its use in the code either.

Diff context:

        ) -> mx.array:
    -       r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
    +       r = self.self_attn(self.input_layernorm(x), mask, cache)

Review comment (Member): ❤️

Diff context:

    @@ -198,20 +197,12 @@ def __call__(
            self,
            x: mx.array,
            mask: mx.array = None,
    -       cache: mx.array = None,
    +       cache=None,
        ) -> Tuple[mx.array, mx.array]:

Review comment (Member): I don't personally care, but in most cases this is written as cache: Optional[Any] = None.

@awni (Member, Author): Good catch, will fix.

"""
cache_data, cache_info = zip(*(c.state for c in cache))
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

@awni (Member, Author) commented Oct 7, 2024

> I am wondering what is the point of the extra state in the KV cache? Is anybody using it now? Is there any reason it is set to the empty string instead of None?

It's my least favorite thing in this diff, but I haven't thought of a cleaner solution yet (if you have ideas, I'm all 👂).

It is used only for the RotatingKVCache so that we can save cache.offset and cache._idx; otherwise we don't get the same behavior when serializing and deserializing that cache.

The reason I made it a string and not None is that it simplified saving it in the safetensors metadata. Downstream code just does something like dict(tree_flatten([c for c in cache.state[1]])).
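
Concretely, safetensors metadata is a flat string-to-string mapping, which is what pushes the extra state towards "" rather than None. A hedged sketch of the save path implied by the snippet above (mx.save_safetensors and tree_flatten are real MLX utilities; the surrounding function and the cache_classes key are illustrative):

    import mlx.core as mx
    from mlx.utils import tree_flatten

    def save_cache_sketch(file_name, cache):
        # Each cache's .state is assumed to be (arrays, extra_string_state).
        cache_data, cache_info = zip(*(c.state for c in cache))
        arrays = dict(tree_flatten(cache_data))    # tensor payload
        metadata = dict(tree_flatten(cache_info))  # must map str -> str
        metadata["cache_classes"] = ",".join(type(c).__name__ for c in cache)
        mx.save_safetensors(file_name, arrays, metadata=metadata)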

@angeloskath (Member):
Huh, not sure how I managed to miss it in the RotatingKVCache...

I think this should be changed because state is what we evaluate from the caches, which is why the string confused me. The smallest change that would, imho, be significantly better is to split it into state and serialization_state. It is a bit verbose, but at least it separates the two types of information cleanly.

Nothing much changes. Line 47 would do something like

cache_data = [c.state for c in cache]
cache_info = [c.serialization_state for c in cache] # do we also want type(self).__name__ here?

and line 75 would change to

for c, state, serialization_state in zip(cache, arrays, info):
    c.state = state
    c.serialization_state = serialization_state

The rest remains the same... Wdyt?
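
To make the proposal concrete, a RotatingKVCache exposing both properties might look roughly like this (a hypothetical sketch of the suggestion, not the merged code):

    class RotatingKVCacheSketch:
        def __init__(self):
            self.keys = None
            self.values = None
            self.offset = 0
            self._idx = 0

        @property
        def state(self):
            # Array payload only: what gets evaluated and saved as tensors.
            return self.keys, self.values

        @state.setter
        def state(self, v):
            self.keys, self.values = v

        @property
        def serialization_state(self):
            # String-only extras needed to restore behavior after loading.
            return str(self.offset), str(self._idx)

        @serialization_state.setter
        def serialization_state(self, v):
            self.offset, self._idx = map(int, v)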

@awni (Member, Author) commented Oct 7, 2024

Yea, I thought about a separate property and/or overriding __getstate__ and __setstate__. The main downside I didn't like is that all the caches would need to implement it, but maybe the right call is to check if the attribute exists to avoid that. I think you're right that it could be cleaner, even if a little more verbose.

@angeloskath (Member):
I added a small base class that implements the empty meta state and makes the load/save code a tad cleaner. Should I push it on top, or are we avoiding base classes for some reason?

Also, I just played with the chat command; it is an absolute joy to use :-)

@awni (Member, Author) commented Oct 7, 2024

> I added a small base class that implements the empty meta state and makes the load/save code a tad cleaner. Should I push it on top, or are we avoiding base classes for some reason?

No reason at this point, please send the diff!
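
For illustration, such a base class might look roughly like the following; this is a hedged sketch of the idea (class and property names assumed), not the diff that was actually pushed:

    class _BaseCacheSketch:
        # Default empty meta state so save/load code can treat every cache
        # uniformly; "" keeps the string-only safetensors metadata simple.
        @property
        def state(self):
            return []

        @state.setter
        def state(self, v):
            if v:
                raise ValueError("This cache has no state to set.")

        @property
        def meta_state(self):
            return ""

        @meta_state.setter
        def meta_state(self, v):
            if v:
                raise ValueError("This cache has no meta_state to set.")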

@awni (Member, Author) commented Oct 8, 2024

Ok, I tested prompt caching with a few different models / cache types and it seems to work well. I'm going to merge this.

As a follow up we should consider:

  • Making a way to serialize the chat context
  • Adding a chat endpoint to mlx_lm.server with prompt caching

@awni merged commit fca087b into main on Oct 8, 2024 (2 checks passed), and deleted the more_cache_improvements branch on Oct 8, 2024 at 03:45.
@zcbenz (Contributor) commented Oct 8, 2024

Awesome work, thanks for fixing it!

@mark-lord:
So excited this got merged 😄

Linked issue closed by this PR: RotatingKVCache: Problem when reusing cache between multiple generations (#1000)