
Allow prompt callback to generate_step #1133

Merged · 4 commits merged into main on Dec 4, 2024
Conversation

awni (Member) commented on Dec 2, 2024:

  • Add prompt callback to generate step
  • Use the callback in cache_prompt (simplifies some code)
  • Refactor generate_step to accept max_tokens.

CC @neilmehta24, the prompt callback takes two arguments (the number of prompt tokens processed and the total number of prompt tokens), since that's more flexible. Let me know if that works for you.
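
A minimal usage sketch of the callback described above. The keyword argument name `prompt_progress_callback` and the model path are assumptions for illustration; the exact names may differ from the merged code:

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

# Callback receives (prompt tokens processed so far, total prompt tokens).
def report_progress(processed, total):
    print(f"prompt processing: {processed}/{total} tokens")

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example model
prompt = mx.array(tokenizer.encode("Summarize what a KV cache does."))

for token, _logprobs in generate_step(
    prompt,
    model,
    max_tokens=64,                              # accepted directly per this PR
    prompt_progress_callback=report_progress,   # assumed keyword name
):
    print(tokenizer.decode([token.item()]), end="", flush=True)
```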

neilmehta24 (Contributor) replied:

That will work great. Thanks for this feature!

angeloskath (Member) left a review:

Very nice!

I would change the `n < max_tokens` to `n != max_tokens` and document the negative `max_tokens`, and I think it is golden :-)
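
A toy, self-contained sketch (not the repository's code) of why the `n != max_tokens` comparison matters: a negative `max_tokens` such as -1 never equals the counter, so it means "generate until another stop condition fires", whereas `n < max_tokens` would stop immediately:

```python
import itertools

EOS_TOKEN = 0  # illustrative end-of-sequence id

def fake_sampler(step):
    # Stand-in sampler that never emits EOS, so only the loop condition stops us.
    return step + 1

def generate_step(max_tokens=256):
    n = 0
    while n != max_tokens:   # -1 never matches, so generation is unbounded
        token = fake_sampler(n)
        if token == EOS_TOKEN:
            break
        yield token
        n += 1

print(list(generate_step(max_tokens=5)))                        # exactly 5 tokens
print(list(itertools.islice(generate_step(max_tokens=-1), 3)))  # unbounded; take 3
```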

(Review threads on llms/mlx_lm/utils.py — resolved.)
Code context in the diff:

```python
maybe_quantize_kv_cache(
    cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
)
for _ in generate_step(
```

A reviewer (Member) commented on this code:
This does an unnecessary sampling but I like very much that it removes so much duplicated logic.

awni (Member Author) replied on Dec 3, 2024:

Right, it does the async eval on the first token / log probs. It should be negligible compared to the rest of the computation.

It's nice that it simplifies the logic a bit; mostly I changed this to have an example use of the callback and to verify that it was working correctly.

Another option I considered is to split out a prefill_prompt API that both generate_step and cache_prompt.py can use. Maybe that's better, but so far we don't need it anywhere except here. If we end up needing to do the prompt computation in more places, I will play around with that.
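
For illustration only, a hypothetical sketch of what such a `prefill_prompt` split could look like; the function, its signature, and the chunking step size are assumptions, not part of this PR:

```python
import mlx.core as mx

def prefill_prompt(model, prompt, cache, prefill_step=512, progress_callback=None):
    """Hypothetical helper: run the prompt through the model in chunks to fill
    the KV cache in place, without sampling any tokens."""
    total = prompt.shape[0]
    processed = 0
    while processed < total:
        chunk = prompt[processed : processed + prefill_step][None]  # add batch dim
        logits = model(chunk, cache=cache)   # forward pass populates the cache
        mx.eval(logits)                      # force evaluation of the lazy graph
        processed = min(processed + prefill_step, total)
        if progress_callback is not None:
            progress_callback(processed, total)
```

Both generate_step and cache_prompt.py could then call a helper like this before their own loops, which is the split discussed above.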

awni (Member Author) commented on Dec 3, 2024:

Also closes #1134

awni merged commit 1963df8 into main on Dec 4, 2024 (2 checks passed) and deleted the prompt_callback branch on December 4, 2024 at 00:17.
mokeddembillel pushed a commit to mokeddembillel/mlx-examples referencing this pull request on Dec 16, 2024, with the message:
* allow prompt callback and use in cache_prompt

* nit

* comments

* bump version