Implement DRY penalty #637
Conversation
Code Metrics Report

```
===============================================================================
 Language             Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                 2           35           28            0            7
 Dockerfile               1           34           25            0            9
 Happy                    1          442          369            0           73
 JSON                    11          102          101            0            1
 Python                  46         2018         1718           62          238
 TOML                    20          619          546           11           62
 YAML                     2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks        4            0            0            0            0
 |- Markdown              2           77           32           31           14
 |- Python                2          196          169            1           26
 (Total)                             273          201           32           40
-------------------------------------------------------------------------------
 Markdown                29         2063            0         1568          495
 |- BASH                  5          101           98            0            3
 |- JSON                  1           12           12            0            0
 |- Python                5           92           82            0           10
 |- Rust                  6          408          365           19           24
 |- TOML                  2           75           63            0           12
 (Total)                            2751          620         1587          544
-------------------------------------------------------------------------------
 Rust                   198        61913        56251         1123         4539
 |- Markdown            102          946           13          881           52
 (Total)                           62859        56264         2004         4591
===============================================================================
 Total                  315        67247        59057         2766         5424
===============================================================================
```
Thank you for implementing this so quickly! I have submitted my review in the form of a pull request into this branch: #645.
* Silence bogus Clippy warning: Clippy's suggestion cannot be implemented because of borrowing issues
* Get rid of unnecessary type annotations: interesting that Clippy doesn't catch this
* Store default sequence breakers in a slice: it's nicer when the length is not hardcoded
* Make default sequence breakers private: no need to leak this as it's not used elsewhere
* Limit match length: avoids quadratic runtime and potential DoS with adversarial inputs. Ref oobabooga/text-generation-webui#6047
* "Fix" sequence breaker tokenization (see the sketch below): most tokenizers encode punctuation tokens differently depending on where they occur in the input, and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself. To emulate this, prefix the token with "a" before encoding, and extract the final token of the result. See LostRuins/koboldcpp#982 for a correct solution to this problem.
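A minimal sketch of that tokenization workaround using the `tokenizers` crate; the function name is illustrative, and this is not the actual mistral.rs code:

```rust
use tokenizers::Tokenizer;

/// Encode a sequence breaker the way it typically appears in running
/// text: prefix it with "a" so the tokenizer treats it as following a
/// word, then keep only the final token of the result.
fn encode_sequence_breaker(tokenizer: &Tokenizer, breaker: &str) -> Option<u32> {
    let prefixed = format!("a{breaker}");
    // `false`: do not add special tokens; we want only the breaker itself.
    let encoding = tokenizer.encode(prefixed, false).ok()?;
    encoding.get_ids().last().copied()
}
```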
Hey @p-e-w! Do you think this is ready to merge (that is, the implementation is done & correct)?
Give me a few days to test and verify this; I'll let you know once I'm sure. Looks fine at first glance, though!
Just some optional suggestions; I've not really done a proper review.
@polarathene thanks for the review - nice to see you back! I've merged your suggestions!
I'm having a hard time testing this properly because of #666.
mistralrs-core/src/sampler.rs
Outdated
```diff
@@ -488,7 +488,7 @@ impl Sampler {
     let match_indices = toks
         .par_iter()
         .enumerate()
-        .take(toks.len() - 1)
+        .take(toks.len().saturating_sub(1))
```
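For illustration, a standalone snippet (not project code) showing why the change matters when `toks` is empty:

```rust
fn main() {
    let toks: Vec<u32> = Vec::new();
    // `toks.len() - 1` would underflow here: it panics in debug builds
    // and wraps to usize::MAX in release builds. `saturating_sub`
    // clamps the result to 0 instead.
    assert_eq!(toks.len().saturating_sub(1), 0);
}
```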
Just ran into this problem before I noticed your commit. But why can `toks.len()` be 0 here in the first place? That doesn't make sense to me. If I enter a prompt (e.g. "test"), then `apply_dry_penalty` is called with an empty context. Why?
So this can only be caused when sampling the token resulting from prompt processing. We discard all prompt tokens from the penalty context.
Does DRY require those tokens to be included?
I'm afraid I don't understand. Why is sampling needed during prompt processing?
Sure, the transformer will by construction always output logits for each token position. But samplers only get involved once a token is actually drawn from the resulting distribution. And that only happens when new tokens are generated, right? In which case the context should always be non-empty after a prompt has been entered.
> Why is sampling needed during prompt processing?
After the model processes the prompt, it produces a token distribution as a result, and we sample that. The `Sampler::apply_penalties` method has inherent support for the case where the generated tokens do not exist yet, as we just iterate over that context. Perhaps we need a case here to handle this?
I feel bad for taking so much of your time with this, but I just don't get it. Why does sampling ever happen with an empty context? The only purpose of sampling (i.e., drawing from the probability distribution, rather than merely generating the distribution) is to generate a new token, no?
Let's say the user enters the prompt "Hello", and then runs generation. Why is `Sampler::apply_penalties` called with an empty context, rather than with context "Hello"? The program doesn't need to sample during prompt processing (since no tokens are being generated at previous positions), so why is this happening?
No problem, sorry for any confusion.
> Let's say the user enters the prompt "Hello", and then runs generation. Why is `Sampler::apply_penalties` called with an empty context, rather than with context "Hello"? The program doesn't need to sample during prompt processing (since no tokens are being generated at previous positions), so why is this happening?
We process the prompt and then sample the distribution of the last token to get the next token. So if we don't exclude the prompt, the context would be the prompt here. This is intentional, but do you think we should remove the part where we exclude the prompt (below)?
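To make the two options concrete, a toy sketch (the token IDs are made up for illustration):

```rust
fn main() {
    // Made-up token ID standing in for the encoded prompt "Hello".
    let prompt: Vec<u32> = vec![9906];
    let generated: Vec<u32> = Vec::new();

    // Current behavior: the penalty context holds only generated tokens,
    // so the very first sampling step sees an empty context.
    let ctx_excluding_prompt: &[u32] = &generated;
    assert!(ctx_excluding_prompt.is_empty());

    // Alternative: include the prompt, so the context is non-empty as
    // soon as a prompt has been entered.
    let ctx_including_prompt: Vec<u32> =
        prompt.iter().chain(generated.iter()).copied().collect();
    assert_eq!(ctx_including_prompt, vec![9906]);
}
```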
Thank you, I understand now. So this is intended to be a feature?
It seems that this breaks many assumptions made by frontends (and users) about how things work. The most problematic effect is that it makes sampling a stateful affair, where the original prompt range is somehow "remembered" during generation, and the context given to the samplers doesn't match the input given to the transformer.
In a chat interface, this means repetition penalties cannot take into account previous messages while generating a new one... which is the exact opposite of what we want, since models tend to repeat previous messages. But it also seems wrong philosophically. If you cancel the generation process midway, and then restart it from that position, you might get different tokens than you would have if you had allowed it to complete, because now the prompt includes the partially generated output.
I don't believe other loaders do this, and IMO this mechanism should indeed be removed completely, for all samplers.
Thanks for pointing this out, I agree that the statefulness makes this incorrect!
> I don't believe other loaders do this, and IMO this mechanism should indeed be removed completely, for all samplers.
Sounds good! I've removed this functionality completely now.
@p-e-w, does this look correct?
Credit to @p-e-w for finding this!

Co-authored-by: Philipp Emanuel Weidmann <pew@worldwidemann.com>
* Add custom logits processor api
* Typos
* Nicer interface and update example
* Fix doctest
* Update docs
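As a rough illustration of the general pattern behind a custom logits processor API; this is a hypothetical trait for illustration, not the actual mistral.rs interface (which may operate on tensors rather than slices):

```rust
/// Hypothetical trait: mutate raw logits in place before sampling.
trait LogitsProcessor: Send + Sync {
    fn apply(&self, logits: &mut [f32], context: &[u32]);
}

/// Toy processor that makes one token ID impossible to sample.
struct BanToken(usize);

impl LogitsProcessor for BanToken {
    fn apply(&self, logits: &mut [f32], _context: &[u32]) {
        if let Some(logit) = logits.get_mut(self.0) {
            *logit = f32::NEG_INFINITY;
        }
    }
}

fn main() {
    let mut logits = vec![0.1_f32, 2.3, -0.5];
    BanToken(1).apply(&mut logits, &[]);
    assert_eq!(logits[1], f32::NEG_INFINITY);
}
```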
* Add gemma2 paged attn support
* Non cuda support?
* Remove error
* It works

* Support GGUF bf16 tensors
* Fix loading of bf16 ggml tensor
* Fix dequant of bf16
* Use merged rev

…on (#707)
* Flash attention varlen kind of works
* Seems to work
* Now it's nice
* Sliding window support and clippy
* Remove warning
* Support smollm
* Update rev to match merged

* Update image_seq_len
* Update the examples
* Format

* Copy the model
* Add most of it
* Add the blocksparse moe parts
* Clippy
* Fix mscales
* A batch of fixes
* Correctly cast it
* Handle isq on gate
* Even more progress
* Runs now
* Clippy
* Fix to use layernorm
* Remove unused
* Add docs
- Verified that the code is equivalent to the original Python implementation.
- Instrumented `apply_dry_penalty` and checked that it produces the expected penalties.
- Tested with repetitive output and validated that this version of DRY actually prevents repetition.
As far as I'm concerned, this is now ready to be merged (after applying the two changes above).
One thing you might consider is to disable DRY completely if `multiplier` is 0, which is what other implementations do. Currently, matching is still performed in this case but has no effect because the resulting penalty is zero. That's a lot of unnecessary work that could be skipped by simply not invoking `apply_dry_penalty` in the first place (and the same optimization could be applied to `apply_freq_presc_penalty`, I think).
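A minimal sketch of that short-circuit; the struct and field names are assumptions for illustration, not the actual mistral.rs code:

```rust
struct Sampler {
    dry_multiplier: f32,
}

impl Sampler {
    fn apply_dry_penalty(&self, _logits: &mut [f32], _ctx: &[u32]) {
        // ...expensive substring matching over the context...
    }

    fn apply_penalties(&self, logits: &mut [f32], ctx: &[u32]) {
        // A zero multiplier cannot change any logit, so skip the
        // matching work entirely.
        if self.dry_multiplier != 0.0 {
            self.apply_dry_penalty(logits, ctx);
        }
    }
}

fn main() {
    let sampler = Sampler { dry_multiplier: 0.0 };
    let mut logits = vec![0.0_f32; 4];
    sampler.apply_penalties(&mut logits, &[1, 2, 3]); // no-op: DRY disabled
}
```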
@p-e-w @polarathene thank you for your reviews! I'll merge this PR as it looks good and generation is great with it!
@p-e-w, could you please give the implementation a quick check? I'm not sure if you are familiar with Rust, but I ported the algorithm from the oobabooga implementation you linked.
Refs #635.