diff --git a/src/logits_processor.rs b/src/logits_processor.rs index 23ad6ea..476d102 100644 --- a/src/logits_processor.rs +++ b/src/logits_processor.rs @@ -283,14 +283,15 @@ impl LogitsProcessor { } fn sample_topkp(&mut self, probs: &mut Vec, top_k: usize, top_p: f32) -> Result { - // Sort probs into descending order (highest probs first) - probs.sort_by(|x, y| x.total_cmp(y)); + let mut argsort_indices = (0..probs.len()).collect::>(); + + // Sort by descending probability. + argsort_indices.sort_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap()); - // TOP K // Clamp smaller probabilities to zero. - for (index, val) in probs.iter_mut().enumerate() { + for (index, val) in argsort_indices.iter().enumerate() { if index >= top_k { - *val = 0.0; + probs[*val] = 0.0; } }