perplexity : make Winogrande work as it does on master
The problems with the Winogrande implementation will
need to be fixed in a separate PR to ease review.
compilade committed Mar 19, 2024
1 parent d04cfaf commit 8f70dcb
Showing 1 changed file with 36 additions and 51 deletions.
examples/perplexity/perplexity.cpp (87 changes: 36 additions & 51 deletions)
@@ -999,6 +999,8 @@ struct winogrande_entry {
     size_t i_logits;
     size_t common_prefix;
     size_t required_tokens;
+    size_t n_base1; // number of tokens for context + choice 1
+    size_t n_base2; // number of tokens for context + choice 2
     std::vector<llama_token> seq_tokens[2];
 };

@@ -1038,38 +1040,6 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
         auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
         auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
         auto index = line.substr(0, comma_pos[0]);
-        if ('a' <= sentence[0] && sentence[0] <= 'z') {
-            // make the first letter a capital letter
-            sentence[0] -= 'a' - 'A';
-        }
-        for (int i = 0; i < (int) sentence.size() - 1; ++i) {
-            // trim repeated spaces and spaces before punctuation
-            if (sentence[i] == ' ') {
-                char next = sentence[i+1];
-                if (next == ' ' || next == ',' || next == '.' || next == '\'') {
-                    char r[2] = { next, 0 };
-                    sentence.replace(i, 2, r);
-                    --i; // stay at the same index for repeated spaces
-                }
-            } else if (sentence[i] == ',' || sentence[i] == '.') {
-                if (sentence[i] == sentence[i+1]) {
-                    // trim repeated punctuation (forward to work at the end of sentences)
-                    char r[2] = { sentence[i], 0 };
-                    sentence.replace(i, 2, r);
-                    --i; // same index to then run the other checks on that punctuation
-                } else if (0 < i && sentence[i-1] == sentence[i]) {
-                    // trim repeated punctuation (looks back to work with the space trim)
-                    char r[2] = { sentence[i], 0 };
-                    sentence.replace(i-1, 2, r);
-                    i -= 2; // go back because content was shifted
-                } else if (sentence[i+1] != ' ') {
-                    // add missing space after punctuation
-                    // (since the loop stops before the end, this adds no trailing space)
-                    char r[3] = { sentence[i], ' ', 0 };
-                    sentence.replace(i, 1, r);
-                }
-            }
-        }
         int where = 0;
         for ( ; where < int(sentence.size()); ++where) {
             if (sentence[where] == '_') break;
@@ -1106,6 +1076,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
  */
 static void winogrande_score(llama_context * ctx, const gpt_params & params) {
 
+    constexpr int k_min_trailing_ctx = 3;
+
     auto data = load_winogrande_from_csv(params.prompt);
     if (data.empty()) {
         fprintf(stderr, "%s: no tasks\n", __func__);
@@ -1150,11 +1122,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             task.common_prefix++;
         }
 
+        // TODO: the last token of each of the sequences don't need to be evaluated
         task.required_tokens = task.common_prefix +
             task.seq_tokens[0].size() - task.common_prefix +
-            task.seq_tokens[1].size() - task.common_prefix
-            // the last tokens don't need to be evaluated
-            - 2;
+            task.seq_tokens[1].size() - task.common_prefix;
+
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
     }
 
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
@@ -1201,8 +1175,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             n_logits += 1;
 
             for (int s = 0; s < 2; ++s) {
-                // end before the last token, no need to predict past the end of the sequences
-                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size() - 1; ++i) {
+                // TODO: end before the last token, no need to predict past the end of the sequences
+                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
                     llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                     n_logits += 1;
                 }
@@ -1234,38 +1208,49 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         for (size_t i = i0; i < i1; ++i) {
             auto & task = data[i];
 
-            // start from the end of the common prefix
-            size_t li = 0;
-            for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) {
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
+            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            size_t li = n_base1 - task.common_prefix;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
                 eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
             }
-            // first token of the second choice is predicted by the end of the common prefix
-            eval_pairs.emplace_back(task.i_logits, task.seq_tokens[1][task.common_prefix]);
-            for (size_t j = task.common_prefix; j < task.seq_tokens[1].size()-1; ++j) {
+            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            // FIXME: this uses the wrong first logits when not skipping the choice word
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
                 eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
            }
-            if (i < i1 - 1) {
-                // make sure all logits have been processed as expected
-                GGML_ASSERT(task.i_logits + li == data[i+1].i_logits);
-            }
         }
         compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
 
         size_t ir = 0;
         for (size_t i = i0; i < i1; ++i) {
             auto & task = data[i];
 
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
             float score_1st = 0;
-            for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) {
+            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
                 score_1st += eval_results[ir++];
             }
-            score_1st /= (task.seq_tokens[0].size() - task.common_prefix);
+            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
 
             float score_2nd = 0;
-            for (size_t j = task.common_prefix-1; j < task.seq_tokens[1].size()-1; ++j) {
+            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
                 score_2nd += eval_results[ir++];
             }
-            score_2nd /= (task.seq_tokens[1].size() - task.common_prefix);
+            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
 
             int result = score_1st > score_2nd ? 1 : 2;
 
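For reference, here is a minimal standalone sketch (not part of this commit; the helper name and toy values are illustrative) of the decision both the old and the restored scoring paths ultimately make: sum the log-probabilities of each choice's evaluated continuation tokens, divide by the number of scored tokens, and pick the choice with the higher mean.

    // Illustrative only: compare two Winogrande choices by mean token log-probability.
    #include <cstdio>
    #include <vector>

    static int pick_winogrande_answer(const std::vector<float> & logprobs_1st,
                                      const std::vector<float> & logprobs_2nd) {
        float score_1st = 0.0f;
        for (float lp : logprobs_1st) { score_1st += lp; }
        score_1st /= logprobs_1st.size();

        float score_2nd = 0.0f;
        for (float lp : logprobs_2nd) { score_2nd += lp; }
        score_2nd /= logprobs_2nd.size();

        // 1-based result, matching the answer column of the Winogrande CSV
        return score_1st > score_2nd ? 1 : 2;
    }

    int main() {
        // toy per-token log-probabilities for the two continuations
        std::vector<float> lp1 = { -1.2f, -0.4f, -2.0f };
        std::vector<float> lp2 = { -2.5f, -1.9f, -0.8f };
        printf("chosen answer: %d\n", pick_winogrande_answer(lp1, lp2));
        return 0;
    }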
