Fix rotating kv cache size #1093

Merged: angeloskath merged 1 commit into main from fix-rotating-cache on Nov 5, 2024
Conversation

angeloskath (Member) commented:
Currently the rotating KV cache can end up holding either max_size or max_size - 1 entries, depending on whether it was filled by generation or by prompt token processing. This PR ensures the maximum is always max_size, never max_size - 1.
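
To make the off-by-one concrete, here is a toy sketch of the two fill paths. The function names and trim logic below are hypothetical illustrations, not the actual RotatingKVCache code in mlx_lm:

# Toy model of the inconsistency (hypothetical helpers, not mlx_lm code):
# one fill path settles at max_size entries, the other at max_size - 1.

def size_after_generation(n_tokens, max_size):
    # Token-by-token generation: once the cache is full, each new token
    # evicts exactly one old entry, so the steady-state size is max_size.
    return min(n_tokens, max_size)

def size_after_prompt_processing(n_tokens, max_size):
    # Batch prompt processing that trims to max_size - 1 to leave room for
    # the next generated token, leaving the cache one entry short.
    return min(n_tokens, max_size - 1)

max_size = 100
print(size_after_generation(150, max_size))         # 100
print(size_after_prompt_processing(150, max_size))  # 99 -- the off-by-one this PR removes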

The following two simple repros demonstrate the problem:

Generate past max size

from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import load, stream_generate

PROMPT = "Tell me a story"

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Generate twice with the same cache so the second pass pushes it past max_kv_size.
cache = make_prompt_cache(model, max_kv_size=100)
for text in stream_generate(model, tokenizer, PROMPT, prompt_cache=cache, max_tokens=100):
    pass
for text in stream_generate(model, tokenizer, PROMPT, prompt_cache=cache, max_tokens=100):
    pass

Prompt process past max size

from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import load, stream_generate

PROMPT = "Tell me a story "

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Repeat the prompt so prompt processing alone overflows max_kv_size.
cache = make_prompt_cache(model, max_kv_size=512)
for text in stream_generate(model, tokenizer, PROMPT * 384, prompt_cache=cache, max_tokens=50):
    pass

The first breaks on main. Changing line 45 of base.py breaks the second test, while changing the trim size makes sure everything works. These issues also occur in mlx_lm.chat and mlx_lm.cache_prompt; they are just more explicit with the tests above.
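
For intuition only (the real change lives in the cache update logic in base.py and may differ in detail), the trim amount can be computed so that both fill paths converge on exactly max_size retained entries. The helper below is a hypothetical sketch, not the code from this PR:

def trim_amount(current_length, n_new, max_size):
    # Hypothetical helper: number of old entries to drop so that, after
    # appending n_new tokens, at most max_size entries remain.
    return max(current_length + n_new - max_size, 0)

print(trim_amount(100, 1, 100))   # 1: one eviction per generated token once full
print(trim_amount(480, 64, 512))  # 32: a prompt chunk that overflows max_size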

angeloskath requested a review from awni on November 5, 2024 at 07:22
awni (Member) left a comment:


Looks good thanks for digging into that and fixing it!

angeloskath merged commit ed9e81d into main on Nov 5, 2024
2 checks passed
angeloskath deleted the fix-rotating-cache branch on November 5, 2024 at 18:24
zcbenz added a commit to frost-beta/llm.js that referenced this pull request Dec 7, 2024