Skip to content

Commit

Permalink
Fix fully topk impl
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Apr 2, 2024
1 parent 6bccd29 commit 15f9eb1
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/logits_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,15 @@ impl LogitsProcessor {
}

fn sample_topkp(&mut self, probs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<Logprobs> {
// Sort probs into descending order (highest probs first)
probs.sort_by(|x, y| x.total_cmp(y));
let mut argsort_indices = (0..probs.len()).collect::<Vec<_>>();

// Sort by descending probability.
argsort_indices.sort_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());

// TOP K
// Clamp smaller probabilities to zero.
for (index, val) in probs.iter_mut().enumerate() {
for (index, val) in argsort_indices.iter().enumerate() {
if index >= top_k {
*val = 0.0;
probs[*val] = 0.0;
}
}

Expand Down

0 comments on commit 15f9eb1

Please sign in to comment.