
added implementation of DRY sampler #6839

Closed
wants to merge 21 commits

Conversation

l3utterfly
Contributor

As seen here: oobabooga/text-generation-webui#5677

@p-e-w I took the liberty of porting your DRY sampler to c++ and adding it to llama.cpp

This implementation is used directly after the repeat penalty in llama.cpp, and uses the same repeat_penalty_n for DRY_range.

Thank you for the original implementation; from my tests it works amazingly well. It is especially useful on mobile, where we can only run small quants, and in my experience smaller quants cause excessive repetition.

@Fristender

Thanks for the port! Hopefully this can get merged soon.


@p-e-w p-e-w left a comment


Thank you for getting this started! I would have made a PR here myself eventually, and in fact I might have written the original implementation for llama.cpp, but my two previous sampler-related PRs (1, 2) have received very little maintainer feedback with no way forward, and I don't enjoy putting effort into a black hole.

I have a pretty major performance concern with this implementation, which I have detailed above. Once that is addressed (which will presumably require a fairly comprehensive refactor) I will do a more in-depth review.

Regardless of the outcome of this PR, I'm happy that people like DRY! Repetition penalties are unfortunately still needed even with Llama 3, but we can do a lot better than the current default ones.

llama.cpp Outdated
@@ -13044,6 +13044,63 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) {
// loop through each candidate
for (size_t i = 0; i < candidates->size; ++i) {

Unless I'm misunderstanding something, this implementation is highly inefficient.

You're looping through all candidates. Which, before truncation samplers are applied, are simply the full vocabulary of the tokenizer. For today's LLMs, that's typically between 32k and 200k.

Then in the inner loop, you're going over all tokens in the context. Which can be tens of thousands of tokens as well.

So we're talking at minimum (candidates->size - seq_breakers_size) * last_tokens_size iterations, which is possibly hundreds of millions. And that's not even counting the innermost sequence matching loop!

My original implementation works like this instead:

  1. Find all occurrences of the last token in the context. This is typically a small percentage of the context size.
  2. For each occurrence, try to extend the match backwards as far as possible. Which is usually 0-2 tokens.
  3. Apply penalties to the tokens following the relevant matches.

In practice, this amounts to a single search over the context, plus a small overhead where matches are found. That's five orders of magnitude less work than your current implementation, assuming a vocabulary size in the tens of thousands.
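For readers unfamiliar with the algorithm, here is a minimal, self-contained sketch of the three steps described above. This is an illustration only, not the PR's code; sequence breakers and range limiting are omitted, and all names are made up for the example.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <unordered_map>
    #include <vector>

    // Illustrative sketch: given the recent context (oldest token first), return
    // token -> logit penalty for tokens that would extend a repetition.
    static std::unordered_map<int, float> dry_penalties(
            const std::vector<int> & ctx,
            float base, float multiplier, int allowed_length) {
        std::unordered_map<int, float> penalties;
        if (ctx.size() < 2) {
            return penalties;
        }
        const int    last = ctx.back();
        const size_t end  = ctx.size() - 1;
        std::unordered_map<int, int> match_len;   // candidate token -> longest match
        // 1. find earlier occurrences of the last token
        for (size_t i = 0; i < end; ++i) {
            if (ctx[i] != last) {
                continue;
            }
            // 2. extend the match backwards as far as possible
            int len = 1;
            while (len <= (int) i && ctx[i - len] == ctx[end - len]) {
                ++len;
            }
            // 3. the token that followed this occurrence is a repetition candidate
            const int next = ctx[i + 1];
            match_len[next] = std::max(match_len[next], len);
        }
        for (const auto & [tok, len] : match_len) {
            if (len > allowed_length) {
                penalties[tok] = multiplier * std::pow(base, (float) (len - allowed_length));
            }
        }
        return penalties;
    }

With a vocabulary of tens of thousands of tokens, this touches each context position once instead of once per candidate, which is where the claimed speedup comes from.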

Contributor Author


Thanks for the suggestions! Your method makes much more sense.

I think I didn't notice this because I was using the default top_k = 40 in llama.cpp, so it was only 40 tokens max 😂


DRY should be applied before any truncation samplers such as top_k. That's because you want sufficiently penalized tokens to be truncated away by those samplers. Otherwise you can be left with a distribution that violates the coherence properties that truncation is supposed to provide.

So this is another important thing that needs to be changed.
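A toy example (not llama.cpp code, names invented for illustration) of why the ordering matters: a sufficiently penalized token should be removed by truncation rather than survive it.

    #include <algorithm>
    #include <vector>

    // Toy illustration only: candidates with ids and logits, a top-k truncation,
    // and a DRY-style penalty applied before it.
    struct Cand { int id; float logit; };

    static void top_k(std::vector<Cand> & cands, size_t k) {
        std::sort(cands.begin(), cands.end(),
                  [](const Cand & a, const Cand & b) { return a.logit > b.logit; });
        if (cands.size() > k) {
            cands.resize(k);
        }
    }

    int main() {
        std::vector<Cand> cands = { {0, 5.0f}, {1, 4.9f}, {2, 4.0f}, {3, 1.0f} };
        cands[0].logit -= 10.0f;   // DRY penalizes a would-be repetition first...
        top_k(cands, 2);           // ...so truncation keeps ids 1 and 2, not id 0
        return 0;
    }

If the penalty were applied after top_k instead, the penalized token would stay in the truncated set with a distorted probability.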

Contributor Author

@l3utterfly l3utterfly Apr 24, 2024


@p-e-w I have updated the code with your suggestions.

Looking more closely, I believe I misunderstood some of your logic in my last implementation (which was pretty naive of me 😂). I've refactored the sampler to better match your original implementation. Please take a look when you get the chance!

Currently, DRY is applied just after the repetition penalty, before any top k and temperature samplers.

@l3utterfly
Contributor Author

I messed up the merge on my branch, so this one now contains a huge number of changes. Not sure how to reset it, so I'm going to create a new PR from the latest master branch of llama.cpp.

Closing this

@l3utterfly l3utterfly closed this Apr 25, 2024
@l3utterfly l3utterfly deleted the dry-sampler branch April 25, 2024 06:51
@l3utterfly l3utterfly restored the dry-sampler branch April 26, 2024 02:24
@l3utterfly
Contributor Author

Re-opened PR after force push

@p-e-w

p-e-w commented Apr 26, 2024

Perfect. Now force-push again and we should be in the clear.


@p-e-w p-e-w left a comment


I did a fairly thorough review and can confirm that, modulo some small issues, this is a faithful translation of my original Python implementation. I did not actually run the code though.

Please consider adding a separate dry_range parameter to control the penalty range independently of the standard repetition penalties. This is what I had originally implemented but it was removed by the maintainer; a discussion about this is ongoing in the original PR.

const bool penalize_nl = params.penalize_nl;

// DRY sampler parameters
const float dry_multiplier = params.dry_multiplier;

Too much indentation before the assignment operator for this block.

Contributor Author


The indentation here is meant to match the assignments below, starting from dry_allowed_length. What is the convention here?

(screenshot of the surrounding parameter block omitted)


Not sure what you mean? Both the blocks above the new parameters (ending with penalize_nl) and below them (starting with prev) have the equals signs left-aligned as closely as possible to the LHS, whereas the new parameters have three extra spaces.

But code style is really the maintainers' business. I don't care that much, just something I noticed.

llama.cpp Outdated
@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
// sanity check
GGML_ASSERT(last_tokens_size > 0);

Do all models use BOS tokens? Because if not, this assertion might fail with an empty context.

Contributor Author


I replaced this with an if check instead. I'm not sure if all models use BOS tokens.
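For reference, the guard adopted instead of the assertion looks roughly like this (a sketch; names are assumptions, not the exact diff):

    // bail out early on an empty context instead of asserting
    if (last_tokens == nullptr || last_tokens_size < 1) {
        return; // nothing to penalize yet
    }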

llama.cpp Outdated
// if the match length is greater than our allowed length in config, we apply penalties
if (match_length > dry_allowed_length) {

// find our next token in the candidates->data

Aren't the candidates indices equal to the token ID? In Transformers, this is the case, which is why the original PR doesn't need to search.

If this isn't true for llama.cpp, how are the candidates ordered?

Contributor Author


I looked through how candidates is created. It appears that index == token ID does hold in this case, but it may not always be true: the candidates structure has a bool sorted flag, and if it is true, the candidates are sorted by logit in descending order.

We could check for that condition here? But I cannot determine whether the candidates are guaranteed to have index == token ID when sorted == false.


I see, I guess the purpose of sorting by logit is to simplify truncation samplers.

Probably best to keep the current code then. There are of course possible optimizations (such as interchanging the two loops and deleting tokens from match_lengths once they have been found, which should roughly cut the execution time in half), but I'm not sure if they are worth the extra complexity.
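For clarity, the lookup being discussed is essentially the following (a sketch using llama.cpp's llama_token_data fields, not the PR's exact code; next_token and penalty stand in for the PR's local variables):

    #include "llama.h"

    // Because candidates may be sorted by logit (index != token id), find the
    // entry for next_token by scanning ids, then subtract the DRY penalty.
    static void dry_apply_penalty(llama_token_data_array * candidates,
                                  llama_token next_token, float penalty) {
        for (size_t i = 0; i < candidates->size; ++i) {
            if (candidates->data[i].id == next_token) {
                candidates->data[i].logit -= penalty;
                break;
            }
        }
    }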

Contributor

github-actions bot commented Apr 29, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 206 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=23288.94ms p(95)=41190.24ms fails=, finish reason: stop=96 truncated=110
  • Prompt processing (pp): avg=268.12tk/s p(95)=721.7tk/s
  • Token generation (tg): avg=18.93tk/s p(95)=26.07tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=dry-sampler commit=49e078f79d38da31cdef7f8c4eb80ce25ade3ecf

(Benchmark charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing, each plotted over the 10-minute run.)


@p-e-w p-e-w left a comment


Sampler logic seems to be correct now, AFAICT. Only nits remain; I'm too unfamiliar with the conventions of this project, so I defer the rest to the maintainers.

llama.cpp Outdated
@@ -13233,6 +13233,96 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {

Parameter order is still inconsistent with definitions above (base, multiplier vs. multiplier, base).

float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
uint32_t dry_allowed_length = 2;
uint32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)

An unsigned integer shouldn't be set to -1. That C++ even compiles this is crazy.

There shouldn't be two separate ways to disable the sampler. Setting dry_multiplier to 0 already disables it, no need for a second mechanism.

The correct semantics, IMO, are:

  • last_n = 0: The whole context is searched.
  • last_n > 0: The last last_n tokens are searched.
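Under those semantics, the search window would be computed along these lines (a sketch; the parameter name follows the PR, but the helper itself is hypothetical):

    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper sketching the proposed last_n semantics: a non-positive
    // value searches the whole context, > 0 searches only the last n tokens.
    static size_t dry_search_start(size_t n_ctx_tokens, int32_t dry_penalty_last_n) {
        if (dry_penalty_last_n <= 0) {
            return 0;                                        // whole context
        }
        const size_t n = (size_t) dry_penalty_last_n;
        return n >= n_ctx_tokens ? 0 : n_ctx_tokens - n;     // index of first searched token
    }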

Contributor Author


This will be converted to the maximum value of uint32_t, so ... uh... task failed successfully, I guess. (I will fix this)

Setting dry_penalty_last_n = -1 was meant to keep the same convention as repetition_penalty. I'll update this according to what the maintainer says.

llama.cpp Outdated
@@ -13360,7 +13450,7 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c
const int64_t t_start_sample_us = ggml_time_us();

// no need to do anything if there is only one (or zero) candidates
if(candidates_p->size <= 1) {
if (candidates_p->size <= 1) {

The following changes are in unrelated code and probably shouldn't be in this PR.

@bakkot
Contributor

bakkot commented May 7, 2024

Definitely helps cut down on repetition for me. You can actually see it start to repeat a paragraph and then change course, which is fun.

Would be good to add this to common.cpp too so it's available on the CLI.

@mofosyne mofosyne added the enhancement (New feature or request) and Review Complexity: Medium (Generally require more time to grok but manageable by beginner to medium expertise level) labels on May 9, 2024
@ExtReMLapin
Contributor

Alright, my bad, I missed this part.

@wwoodsTM
Contributor

wwoodsTM commented Aug 5, 2024

I've been working on it; here is what I had for server.cpp. I tried to implement it as Koboldcpp/Ooba do, where either an array of strings or a JSON-encoded array of strings is accepted:

        slot.sparams.penalty_present    = json_value(data, "presence_penalty",   default_sparams.penalty_present);
        slot.sparams.dry_multiplier     = json_value(data, "dry_multiplier",     default_sparams.dry_multiplier);
        slot.sparams.dry_base           = json_value(data, "dry_base",           default_sparams.dry_base);
        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
        slot.sparams.mirostat           = json_value(data, "mirostat",           default_sparams.mirostat);
        // .. skipping over some ..
        slot.sparams.min_keep           = json_value(data, "min_keep",           default_sparams.min_keep);

        // sequence breakers for DRY
        {
            auto dry_seq_breakers = data.find("dry_seq_breakers");
            if (dry_seq_breakers != data.end()) {
                try {
                    if (dry_seq_breakers->is_array()) {
                        slot.sparams.dry_seq_breakers = dry_seq_breakers->get<std::vector<std::string>>();
                    } else if (dry_seq_breakers->is_string()) {
                        slot.sparams.dry_seq_breakers = json::parse(dry_seq_breakers->get<std::string>()).get<std::vector<std::string>>();
                    } else {
                        send_error(task, "\"dry_seq_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST);
                        return false;
                    }
                } catch (const std::exception & e) {
                    send_error(task, std::string("\"dry_seq_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                    return false;
                }
            }
        }

I had just started trying to convert it over to a std::vector<std::string> and was going to ask about potentially excluding sequence breakers as a CLI argument, similar to how stopping strings are excluded. If you want me to continue on anything I can, but otherwise I will let you take over. Hopefully this helps. I appreciate all your efforts on this!

@wwoodsTM
Contributor

wwoodsTM commented Aug 5, 2024

While I have used git a good deal for personal projects, I am a complete noob with github so if I screwed something up or am not following proper etiquette with something, please accept my apologies in advance.

So if I did this correctly, I created a PR directly on your fork with my bit of changes that I had made, both in server.cpp and in common.cpp:

Add DRY sampling parameters to gpt_params and server_context #29

wwoodsTM and others added 2 commits August 5, 2024 00:41
Add DRY sampling parameters to gpt_params and server_context
@l3utterfly
Contributor Author

While I have used git a good deal for personal projects, I am a complete noob with github so if I screwed something up or am not following proper etiquette with something, please accept my apologies in advance.

So if I did this correctly, I created a PR directly on your fork with my bit of changes that I had made, both in server.cpp and in common.cpp:

Add DRY sampling parameters to gpt_params and server_context #29

Thank you so much for this! It is very helpful.

I have merged your changes

@wwoodsTM
Contributor

wwoodsTM commented Aug 6, 2024

No problem! I realize more work was needed for the vector-of-strings implementation to actually work. I have been plugging away at that, but hit a bug where, for whatever reason, the JSON settings I submit to llama-server are not coming through for the DRY settings. I have been trying to get this working before I submit again, but let me know, @l3utterfly, if you would like to see what I have done so far. Otherwise I will keep working on it and submit as soon as I can figure it out, unless you have done any work on that already.

@wwoodsTM
Contributor

wwoodsTM commented Aug 6, 2024

Ok, I just submitted a much more significant PR to your fork. I got my attempt at a somewhat optimized version working, with a vector of strings for the sequence breakers and an attempt at overflow protection (but not much in the way of error handling currently):

Working implementation of DRY with one key issue I could use help with #30

The one issue, as I mention there, is that I had to move the "LLAMA_API void llama_sample_dry" function signature in llama.h to the bottom, into the C++-only "LLAMA_API_INTERNAL" section, to avoid errors on compile. Obviously maintaining C compatibility in this file seems to be a priority, and my use of vectors in the signature of the API function clearly did not fit with that. My assumption is that this probably breaks certain kinds of builds or would mess up bindings or the like? Maybe someone can help me fix this one issue if my version of the implementation is felt to be worthwhile.

Additionally, there is probably a better solution than the borrowed llama_tokenize and llama_detokenize overloads I placed directly in llama-sampling.cpp; someone more familiar with the overall codebase may know one, but it was a quick and dirty way to get it working.

Finally, I welcome any and all help with testing as my own testing has not been exhaustive.

Working implementation of DRY with one key issue I could use help with
@l3utterfly
Contributor Author

@wwoodsTM Thank you! I have merged your changes

@ExtReMLapin
Contributor

From Compilade and Ggerganov's TODO list:

Good to go if everything works correctly, can't wait to test this on python-llama-cpp

@wwoodsTM
Contributor

wwoodsTM commented Aug 8, 2024

From Compilade and Ggerganov's TODO list:

Good to go if everything works correctly, can't wait to test this on python-llama-cpp

I would also add to that list my request that someone with more knowledge of the codebase check the llama.h issue I mentioned above. The other things I mentioned were not as major, but the llama.h thing probably needs to be addressed, as my hunch is that my quick “fix” there breaks something.

@ggerganov
Owner

Just a heads-up, this will likely get merged after #8643

@wwoodsTM
Contributor

wwoodsTM commented Aug 9, 2024

In looking more at this, and keeping in mind the changes coming with #8643, I thought of two possible ideas for dealing with the C interface issue I mentioned before:

A) We could move the sequence breakers to the llama_sampling object and pull them directly from there in the DRY sampling function. If I am understanding the roadmap on all this more or less correctly, this might basically go along decently with the migration of the overall sampling state to llama_sampling (which #8643 conveniently declares in llama.h).

B) Alternatively, we could maybe preprocess the sequence breakers into an array of token arrays, refactor the DRY sampling function to only work with tokens, and maybe have the function signature in llama.h look something like this:

    LLAMA_API void llama_sample_dry(
            struct llama_sampling  * smpl,
           llama_token_data_array  * candidates,
                const llama_token  * last_tokens,
                           size_t    last_tokens_size,
                            float    dry_base,
                            float    dry_multiplier,
                              int    dry_allowed_length,
                const llama_token ** dry_seq_breakers,
                     const size_t  * dry_seq_breaker_lens,
                           size_t    dry_seq_breaker_count);

@l3utterfly @ggerganov @p-e-w @ExtReMLapin @belladoreai Any thoughts? I would be happy to try experimenting with one of these ideas or any other suggestions in my own #8643 sandbox to try to get a head start, but I am hoping to get some guidance on this particular issue.
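As an illustration of the plumbing option (B) would need, here is a sketch under the assumption that the sequence-breaker strings have already been tokenized; the helper and all variable names are hypothetical, and the commented call only mirrors the proposed signature above:

    #include <cstddef>
    #include <vector>
    #include "llama.h"

    // Hypothetical glue for option (B): flatten pre-tokenized sequence breakers
    // into the pointer/length arrays the proposed signature expects.
    static void flatten_seq_breakers(
            const std::vector<std::vector<llama_token>> & tokenized_breakers,
            std::vector<const llama_token *>            & breaker_ptrs,
            std::vector<size_t>                         & breaker_lens) {
        breaker_ptrs.clear();
        breaker_lens.clear();
        for (const auto & toks : tokenized_breakers) {
            breaker_ptrs.push_back(toks.data());
            breaker_lens.push_back(toks.size());
        }
        // llama_sample_dry(smpl, candidates, last_tokens, last_tokens_size,
        //                  dry_base, dry_multiplier, dry_allowed_length,
        //                  breaker_ptrs.data(), breaker_lens.data(), breaker_ptrs.size());
    }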

@ggerganov
Owner

Yes, I think llama_sampling would be suitable to store things like sequence breakers. My immediate plan for #8643 would be to move everything remaining from current common/sampling into llama_sampling

@ddh0
Contributor

ddh0 commented Sep 11, 2024

Yes, I think llama_sampling would be suitable to store things like sequence breakers. My immediate plan for #8643 would be to move everything remaining from current common/sampling into llama_sampling

Now that #9294 is complete, is there a roadmap for DRY to be merged into mainline?

@wwoodsTM
Contributor

In retrospect, I think @ggerganov's previous comment about this being merged after #9294 was part vote-of-confidence and part warning, as he knew the changes that #9294 was going to introduce would likely require some significant rewriting of this PR outside of the core DRY sampling functionality.

When I tried working on this part-way through the completion of #9294, and was looking at how extensive the changes to the overall sampling code are, I quickly realized that at least for a very rusty C++ coder like myself, it would be much wiser to just wait for the refactor to be complete before trying to do anything.

Now that it is complete, I am going to start digging into this again. I welcome any help, especially from those with more experience with the overall code-base.

@wwoodsTM
Copy link
Contributor

Ok, I’m happy to say I have something really solid at this point! 😅 It’s basically complete, though I noticed some underlying issues with the sampler state not being maintained in the original refactor. My current solution "works," but it may be more of a bandaid than a true fix—definitely something to keep an eye on.

Before I proceed, I wanted to check with you, @l3utterfly: Should I submit this as a PR to this branch directly, as I did before, or would it be better to submit a brand new PR directly to ggerganov/master and list you as a co-author?

I assume the conflicts might be quite messy with how different the branches are now, but I know there are ways to handle that while preserving the commit history. What do you think would be the best route in this case?

@wwoodsTM
Contributor

wwoodsTM commented Oct 1, 2024

I went ahead and made a separate PR (#9702) and listed you as co-author @l3utterfly. If you think it is better for me to undo that I can, as maybe I am being overeager lol, but I have been really looking forward to getting this PR to the finish line... 😄

@ExtReMLapin
Contributor

Can be closed, I guess, as the other one was merged.

@l3utterfly l3utterfly closed this Nov 5, 2024
@l3utterfly l3utterfly deleted the dry-sampler branch November 5, 2024 07:35
Labels
enhancement (New feature or request), examples, Review Complexity: Medium (Generally require more time to grok but manageable by beginner to medium expertise level), server
Projects
None yet