From 5e86b817490dfe26b79f40053c78f850b6c08de0 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Tue, 2 Apr 2024 20:40:30 -0400 Subject: [PATCH] Implement for argmax too --- src/logits_processor.rs | 54 +++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/src/logits_processor.rs b/src/logits_processor.rs index aab8137..fb183ce 100644 --- a/src/logits_processor.rs +++ b/src/logits_processor.rs @@ -97,70 +97,56 @@ impl LogitsProcessor { } fn sample_argmax(&mut self, logits: Tensor) -> Result { - let mut logits_v: Vec = logits.to_vec1()?; + let mut probs: Vec = logits.to_vec1()?; + let argsort_indices = (0..probs.len()).collect::>(); - self.apply_logit_bias(&mut logits_v)?; + self.apply_logit_bias(&mut probs)?; - let next_token = logits_v + let next_token = probs .iter() .enumerate() .max_by(|(_, u), (_, v)| u.total_cmp(v)) .map(|(i, _)| i) .unwrap(); - let logprob = logits_v[next_token].log(10.0); - let tok = logits_v[next_token]; - - let mut sorted = logits_v.clone(); - sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); - // Where the next token is in the sorted - let next_token_index = sorted - .binary_search_by(|w| { - if *w <= tok { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_err(); + let logprob = probs[next_token].log(10.0); + + let mut argsort_indices_sorted = argsort_indices.clone(); + // Sort by descending prob + argsort_indices_sorted.sort_by(|a, b| probs[*b].partial_cmp(&probs[*a]).unwrap()); // These are where the top n are let top_n_toks_range = - next_token_index.saturating_sub(self.top_n_logprobs)..next_token_index; + 0..self.top_n_logprobs; + dbg!(&top_n_toks_range); // The top n's values - let top_n_logprobs = sorted[top_n_toks_range] + let top_n_logprobs = argsort_indices_sorted[top_n_toks_range.clone()] .iter() - .map(|x| x.log(10.0)) + .map(|x| probs[*x].log(10.0)) .collect::>(); // Find where they actually are in the logits let mut top_n_toks = Vec::new(); - for val in top_n_logprobs.iter() { - let idx = logits_v - .binary_search_by(|w| { - if *w <= *val { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_err(); - top_n_toks.push(idx as u32); + for val in top_n_toks_range { + top_n_toks.push(argsort_indices[val]); } let mut bytes = Vec::new(); + dbg!(&top_n_toks); for tok in &top_n_toks { bytes.push( self.tokenizer - .decode(&[*tok], true) + .decode(&[*tok as u32], true) .map_err(|x| Error::Msg(x.to_string()))?, ); } let top_logprobs = zip(bytes, zip(top_n_toks, top_n_logprobs)) .map(|(bytes, (token, logprob))| TopLogprob { - token, + token: token as u32, logprob, bytes, }) .collect::>(); + dbg!(next_token); + Ok(Logprobs { token: next_token as u32, logprob,