diff --git a/Cargo.toml b/Cargo.toml
index 82769b0..0239bfb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,3 +9,5 @@ edition = "2021"
 candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
 candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
 rand = "0.8.5"
+serde = "1.0.193"
+tokenizers = "0.15.0"
diff --git a/src/logits_processor.rs b/src/logits_processor.rs
index 779b9cb..2672991 100644
--- a/src/logits_processor.rs
+++ b/src/logits_processor.rs
@@ -1,11 +1,17 @@
+use std::iter::zip;
+
 use candle_core::{DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
+use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
 
 /// LogitsProcessor for sampling.
-pub struct LogitsProcessor {
+pub struct LogitsProcessor<'a> {
     rng: rand::rngs::StdRng,
     temperature: Option<f64>,
     sampling_method: SamplingMethod,
+    top_n_logprobs: usize,
+    tokenizer: &'a Tokenizer,
 }
 
 /// Sampling method for `LogitsProcessor`.
@@ -20,8 +26,30 @@ pub enum SamplingMethod {
     TopK(usize),
 }
 
-impl LogitsProcessor {
-    pub fn new(seed: u64, temperature: Option<f64>, sampling_method: SamplingMethod) -> Self {
+#[derive(Debug, Clone, Serialize, Deserialize)]
+// Top-n logprobs element
+pub struct TopLogprob {
+    token: usize,
+    logprob: f32,
+    bytes: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Logprobs {
+    token: usize,
+    logprob: f32,
+    bytes: String,
+    top_logprobs: Vec<TopLogprob>,
+}
+
+impl<'a> LogitsProcessor<'a> {
+    pub fn new(
+        seed: u64,
+        temperature: Option<f64>,
+        sampling_method: SamplingMethod,
+        top_n_logprobs: usize,
+        tokenizer: &'a Tokenizer,
+    ) -> Self {
         let temperature = if temperature.map_or(true, |v| v < 1e-7) {
             None
         } else {
@@ -31,27 +59,107 @@ impl LogitsProcessor {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
             sampling_method,
+            top_n_logprobs,
+            tokenizer,
         }
     }
 
-    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+    fn sample_argmax(&mut self, logits: Tensor) -> Result<Logprobs> {
         let logits_v: Vec<f32> = logits.to_vec1()?;
         let next_token = logits_v
             .iter()
             .enumerate()
             .max_by(|(_, u), (_, v)| u.total_cmp(v))
-            .map(|(i, _)| i as u32)
+            .map(|(i, _)| i)
             .unwrap();
-        Ok(next_token)
+        let logprob = logits_v[next_token].log(10.0);
+
+        let mut sorted = logits_v.clone();
+        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
+        let top_n_toks_range = next_token + 1
+            ..if next_token + 1 + self.top_n_logprobs <= logits_v.len() {
+                next_token + 1 + self.top_n_logprobs
+            } else {
+                logits_v.len()
+            };
+        let top_n_toks = top_n_toks_range.clone().collect::<Vec<_>>();
+        let top_n_logprobs = sorted[top_n_toks_range]
+            .iter()
+            .map(|x| x.log(10.0))
+            .collect::<Vec<_>>();
+        let mut bytes = Vec::new();
+        for tok in &top_n_toks {
+            bytes.push(
+                self.tokenizer
+                    .decode(&[(*tok).try_into().unwrap()], true)
+                    .map_err(|x| Error::Msg(x.to_string()))?,
+            );
+        }
+        let top_logprobs = zip(bytes, zip(top_n_toks, top_n_logprobs))
+            .map(|(bytes, (token, logprob))| TopLogprob {
+                token,
+                logprob,
+                bytes,
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Logprobs {
+            token: next_token,
+            logprob,
+            top_logprobs,
+            bytes: self
+                .tokenizer
+                .decode(&[next_token.try_into().unwrap()], true)
+                .map_err(|x| Error::Msg(x.to_string()))?,
+        })
     }
 
-    fn sample_multinomial(&mut self, probs: &Vec<f32>) -> Result<u32> {
+    fn sample_multinomial(&mut self, probs: &Vec<f32>) -> Result<Logprobs> {
         let distr = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
-        let next_token = distr.sample(&mut self.rng) as u32;
-        Ok(next_token)
+        let next_token = distr.sample(&mut self.rng);
+        let logprob = probs[next_token].log(10.0);
+
+        let mut sorted = probs.clone();
+        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
+        let top_n_toks_range = next_token + 1
+            ..if next_token + 1 + self.top_n_logprobs <= probs.len() {
+                next_token + 1 + self.top_n_logprobs
+            } else {
+                probs.len()
+            };
+        let top_n_toks = top_n_toks_range.clone().collect::<Vec<_>>();
+        let top_n_logprobs = sorted[top_n_toks_range]
+            .iter()
+            .map(|x| x.log(10.0))
+            .collect::<Vec<_>>();
+        let mut bytes = Vec::new();
+        for tok in &top_n_toks {
+            bytes.push(
+                self.tokenizer
+                    .decode(&[(*tok).try_into().unwrap()], true)
+                    .map_err(|x| Error::Msg(x.to_string()))?,
+            );
+        }
+        let top_logprobs = zip(bytes, zip(top_n_toks, top_n_logprobs))
+            .map(|(bytes, (token, logprob))| TopLogprob {
+                token,
+                logprob,
+                bytes,
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Logprobs {
+            token: next_token,
+            logprob,
+            top_logprobs,
+            bytes: self
+                .tokenizer
+                .decode(&[next_token.try_into().unwrap()], true)
+                .map_err(|x| Error::Msg(x.to_string()))?,
+        })
     }
 
-    fn sample_topp(&mut self, probs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+    fn sample_topp(&mut self, probs: &mut Vec<f32>, top_p: f32) -> Result<Logprobs> {
         // top-p sampling (or "nucleus sampling") samples from the smallest set of
         // tokens that exceed probability top_p. This way we never sample tokens that
         // have very low probabilities and are less likely to go "off the rails".
@@ -74,7 +182,7 @@ impl LogitsProcessor {
         self.sample_multinomial(probs)
     }
 
-    fn sample_topk(&mut self, probs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
+    fn sample_topk(&mut self, probs: &mut Vec<f32>, top_k: usize) -> Result<Logprobs> {
         // Sort probs into descending order (highest probs first)
         probs.sort_by(|x, y| x.total_cmp(y));
         probs.reverse();
@@ -94,7 +202,7 @@ impl LogitsProcessor {
     ///
     /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
     /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
-    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
+    pub fn sample(&mut self, logits: &Tensor) -> Result<Logprobs> {
         let logits = logits.to_dtype(DType::F32)?;
         let next_token = match self.temperature {
             None => self.sample_argmax(logits)?,
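
For reference, a minimal caller sketch of the changed API. This is not part of the patch: it assumes this file is exposed as a `logits_processor` module of the crate and that a Hugging Face `tokenizer.json` is available locally; the path, seed, top-n count, and logits values are illustrative only.

use candle_core::{Device, Error, Result, Tensor};
use tokenizers::Tokenizer;

use crate::logits_processor::{LogitsProcessor, SamplingMethod};

fn sample_once() -> Result<()> {
    // The tokenizer is only used to decode token ids into the `bytes`
    // fields of `Logprobs` / `TopLogprob`.
    let tokenizer =
        Tokenizer::from_file("tokenizer.json").map_err(|e| Error::Msg(e.to_string()))?;

    // A `None` temperature selects the argmax path; a sampling method is still
    // required by the new constructor. Report up to 2 extra top logprobs.
    let mut processor = LogitsProcessor::new(42, None, SamplingMethod::TopK(32), 2, &tokenizer);

    // A toy rank-1 logits tensor standing in for one model forward pass.
    let logits = Tensor::new(&[0.1f32, 2.0, 0.5, 1.5], &Device::Cpu)?;
    let logprobs = processor.sample(&logits)?;

    // `Logprobs` derives Debug and Serialize, so it can be printed or serialized.
    println!("{logprobs:?}");
    Ok(())
}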