Skip to content

Commit

Permalink
Fix topk only impl
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Apr 2, 2024
1 parent 3d5b552 commit 6bccd29
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/logits_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,15 @@ impl LogitsProcessor {
}

fn sample_topk(&mut self, probs: &mut Vec<f32>, top_k: usize) -> Result<Logprobs> {
// Sort probs into descending order (highest probs first)
probs.sort_by(|x, y| x.total_cmp(y));
let mut argsort_indices = (0..probs.len()).collect::<Vec<_>>();

// Sort by descending probability.
argsort_indices.sort_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());

// Clamp smaller probabilities to zero.
for (index, val) in probs.iter_mut().enumerate() {
for (index, val) in argsort_indices.iter().enumerate() {
if index >= top_k {
*val = 0.0;
probs[*val] = 0.0;
}
}

Expand Down

0 comments on commit 6bccd29

Please sign in to comment.