-
Notifications
You must be signed in to change notification settings - Fork 9.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added implementation of DRY sampler #6839
Changes from 8 commits
f64dea0
aea4ad0
4d603e3
75beda2
85dadac
793e1e2
3caec6b
49e078f
2f9a36a
802ddd7
12bfa78
0229fc8
236da59
e862def
9105cf4
20dc562
d1676a1
ed6b909
6579e64
a18fb2f
190898a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13233,6 +13233,96 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can | |
} | ||
} | ||
|
||
// DRY ("Don't Repeat Yourself") sampler: penalizes candidate tokens that would
// extend a suffix of the context which has already occurred earlier, thereby
// discouraging verbatim repetition.
//
// candidates            - token/logit array; penalized in place
// last_tokens           - previous context tokens, oldest first
// last_tokens_size      - number of entries in last_tokens
// dry_base              - exponential base: penalty = multiplier * base^(match_len - allowed_len)
// dry_multiplier        - linear scale of the penalty; 0 disables the sampler
// dry_allowed_length    - repeated sequences up to this length are not penalized
// dry_seq_breakers      - tokens that interrupt a match (e.g. punctuation)
// dry_seq_breakers_size - number of entries in dry_seq_breakers
void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
    // nothing to do without a previous token or with a zero multiplier;
    // a negative allowed length is rejected explicitly, because comparing it
    // against a size_t below would convert it to a huge unsigned value and
    // silently disable the penalty
    if (last_tokens_size < 1 || dry_multiplier == 0.0f || dry_allowed_length < 0) {
        return;
    }

    // membership test for the sequence-breaker list (used in three places)
    const llama_token * breakers_end = dry_seq_breakers + dry_seq_breakers_size;
    auto is_seq_breaker = [&](llama_token token) {
        return std::find(dry_seq_breakers, breakers_end, token) != breakers_end;
    };

    // the token whose earlier occurrences we search for
    const llama_token last_token = last_tokens[last_tokens_size - 1];

    // if the last token is itself a sequence breaker, skip the whole sampler
    if (is_seq_breaker(last_token)) {
        return;
    }

    // map of "next token" -> longest match length ending right before it
    std::unordered_map<llama_token, size_t> match_lengths;

    // scan every earlier occurrence of last_token (excluding the final position)
    for (size_t i = 0; i + 1 < last_tokens_size; ++i) {
        if (last_tokens[i] != last_token) {
            continue;
        }

        // token that followed the earlier occurrence (i + 1 is in bounds here)
        const llama_token next_token = last_tokens[i + 1];

        // matches that would predict a sequence breaker are ignored
        if (is_seq_breaker(next_token)) {
            continue;
        }

        // try to extend the match backwards; starts at 1 because the last
        // token is already matched by construction
        size_t match_length = 1;
        for (;; match_length++) {
            // reached the start of the context
            if (i < match_length) {
                break;
            }

            // compare token walks back from the earlier occurrence; head token
            // walks back from the end of the context in lockstep
            const llama_token compare_token = last_tokens[i - match_length];
            const llama_token head_token    = last_tokens[last_tokens_size - 1 - match_length];

            // stop extending on the first mismatch
            if (compare_token != head_token) {
                break;
            }

            // a sequence breaker truncates the match
            if (is_seq_breaker(compare_token)) {
                break;
            }
        }

        // keep the longest match observed for this next_token
        auto it = match_lengths.find(next_token);
        if (it == match_lengths.end()) {
            match_lengths[next_token] = match_length;
        } else {
            it->second = std::max(it->second, match_length);
        }
    }

    // apply penalties to candidates whose repeated sequence is long enough
    for (const auto & pair : match_lengths) {
        const llama_token next_token   = pair.first;
        const size_t      match_length = pair.second;

        // cast is safe: dry_allowed_length >= 0 was checked on entry
        if (match_length >= (size_t) dry_allowed_length) {
            // candidates are not guaranteed to be sorted by id, so search
            // linearly for the matching candidate
            for (size_t i = 0; i < candidates->size; ++i) {
                if (candidates->data[i].id == next_token) {
                    // penalty grows exponentially with the excess match length
                    const float penalty = dry_multiplier * std::pow(dry_base, (float) (match_length - (size_t) dry_allowed_length));

                    // apply the dry penalty
                    candidates->data[i].logit -= penalty;
                    break;
                }
            }
        }
    }
}
|
||
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { | ||
if (z >= 1.0f || candidates->size <= 2) { | ||
return; | ||
|
@@ -13360,7 +13450,7 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c | |
const int64_t t_start_sample_us = ggml_time_us(); | ||
|
||
// no need to do anything if there is only one (or zero) candidates | ||
if(candidates_p->size <= 1) { | ||
if (candidates_p->size <= 1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The following changes are in unrelated code and probably shouldn't be in this PR. |
||
return; | ||
} | ||
|
||
|
@@ -13594,7 +13684,7 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ | |
t_start_sample_us = ggml_time_us(); | ||
|
||
// Compute error as the difference between observed surprise and target surprise value | ||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { | ||
size_t X_idx = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { | ||
return candidate.id == X; | ||
})); | ||
float observed_surprise = -log2f(candidates->data[X_idx].p); | ||
|
@@ -13616,7 +13706,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok | |
llama_sample_softmax(ctx, candidates); | ||
|
||
// Truncate the words with surprise values greater than mu | ||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { | ||
candidates->size = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { | ||
return -log2f(candidate.p) > *mu; | ||
})); | ||
|
||
|
@@ -13636,7 +13726,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok | |
t_start_sample_us = ggml_time_us(); | ||
|
||
// Compute error as the difference between observed surprise and target surprise value | ||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { | ||
size_t X_idx = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { | ||
return candidate.id == X; | ||
})); | ||
float observed_surprise = -log2f(candidates->data[X_idx].p); | ||
|
@@ -15686,7 +15776,7 @@ uint64_t llama_model_n_params(const struct llama_model * model) { | |
} | ||
|
||
struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) { | ||
auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(), | ||
auto it = std::find_if (model->tensors_by_name.begin(), model->tensors_by_name.end(), | ||
[name](const std::pair<std::string, struct ggml_tensor *> & it) { | ||
return it.first == name; | ||
}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An unsigned integer shouldn't be set to `-1`. That C++ even compiles this is crazy.

There shouldn't be two separate ways to disable the sampler. Setting `dry_multiplier` to `0` already disables it, no need for a second mechanism.

The correct semantics, IMO, are:
- `last_n = 0`: The whole context is searched.
- `last_n > 0`: The last `last_n` tokens are searched.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will be converted to the maximum value of `uint32_t`, so ... uh ... task failed successfully, I guess. (I will fix this.)

Setting `dry_penalty_last_n=-1` was to keep the same convention as `repetition_penalty`. I'll update this according to what the maintainer says.