Return logprobs
EricLBuehler committed Dec 17, 2023
1 parent 49baa5f commit 85d905f
Showing 2 changed files with 122 additions and 12 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -9,3 +9,5 @@ edition = "2021"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
rand = "0.8.5"
serde = "1.0.193"
tokenizers = "0.15.0"
132 changes: 120 additions & 12 deletions src/logits_processor.rs
@@ -1,11 +1,17 @@
use std::iter::zip;

use candle_core::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;

/// LogitsProcessor for sampling.
pub struct LogitsProcessor {
pub struct LogitsProcessor<'a> {
rng: rand::rngs::StdRng,
temperature: Option<f64>,
sampling_method: SamplingMethod,
top_n_logprobs: usize,
tokenizer: &'a Tokenizer,
}

/// Sampling method for `LogitsProcessor`.
@@ -20,8 +26,30 @@ pub enum SamplingMethod {
TopK(usize),
}

impl LogitsProcessor {
pub fn new(seed: u64, temperature: Option<f64>, sampling_method: SamplingMethod) -> Self {
#[derive(Debug, Clone, Serialize, Deserialize)]
// Top-n logprobs element
pub struct TopLogprob {
token: usize,
logprob: f32,
bytes: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
token: usize,
logprob: f32,
bytes: String,
top_logprobs: Vec<TopLogprob>,
}

impl<'a> LogitsProcessor<'a> {
pub fn new(
seed: u64,
temperature: Option<f64>,
sampling_method: SamplingMethod,
top_n_logprobs: usize,
tokenizer: &'a Tokenizer,
) -> Self {
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
None
} else {
@@ -31,27 +59,107 @@ impl LogitsProcessor {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
sampling_method,
top_n_logprobs,
tokenizer,
}
}

fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
fn sample_argmax(&mut self, logits: Tensor) -> Result<Logprobs> {
let logits_v: Vec<f32> = logits.to_vec1()?;
let next_token = logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.map(|(i, _)| i)
.unwrap();
Ok(next_token)
let logprob = logits_v[next_token].log(10.0);

let mut sorted = logits_v.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
let top_n_toks_range = next_token + 1
..if next_token + 1 + self.top_n_logprobs <= logits_v.len() {
next_token + 1 + self.top_n_logprobs
} else {
logits_v.len()
};
let top_n_toks = top_n_toks_range.clone().collect::<Vec<_>>();
let top_n_logprobs = sorted[top_n_toks_range]
.iter()
.map(|x| x.log(10.0))
.collect::<Vec<_>>();
let mut bytes = Vec::new();
for tok in &top_n_toks {
bytes.push(
self.tokenizer
.decode(&[(*tok).try_into().unwrap()], true)
.map_err(|x| Error::Msg(x.to_string()))?,
);
}
let top_logprobs = zip(bytes, zip(top_n_toks, top_n_logprobs))
.map(|(bytes, (token, logprob))| TopLogprob {
token,
logprob,
bytes,
})
.collect::<Vec<_>>();

Ok(Logprobs {
token: next_token,
logprob,
top_logprobs,
bytes: self
.tokenizer
.decode(&[next_token.try_into().unwrap()], true)
.map_err(|x| Error::Msg(x.to_string()))?,
})
}

fn sample_multinomial(&mut self, probs: &Vec<f32>) -> Result<u32> {
fn sample_multinomial(&mut self, probs: &Vec<f32>) -> Result<Logprobs> {
let distr = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
let next_token = distr.sample(&mut self.rng) as u32;
Ok(next_token)
let next_token = distr.sample(&mut self.rng);
let logprob = probs[next_token].log(10.0);

let mut sorted = probs.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
let top_n_toks_range = next_token + 1
..if next_token + 1 + self.top_n_logprobs <= probs.len() {
next_token + 1 + self.top_n_logprobs
} else {
probs.len()
};
let top_n_toks = top_n_toks_range.clone().collect::<Vec<_>>();
let top_n_logprobs = sorted[top_n_toks_range]
.iter()
.map(|x| x.log(10.0))
.collect::<Vec<_>>();
let mut bytes = Vec::new();
for tok in &top_n_toks {
bytes.push(
self.tokenizer
.decode(&[(*tok).try_into().unwrap()], true)
.map_err(|x| Error::Msg(x.to_string()))?,
);
}
let top_logprobs = zip(bytes, zip(top_n_toks, top_n_logprobs))
.map(|(bytes, (token, logprob))| TopLogprob {
token,
logprob,
bytes,
})
.collect::<Vec<_>>();

Ok(Logprobs {
token: next_token,
logprob,
top_logprobs,
bytes: self
.tokenizer
.decode(&[next_token.try_into().unwrap()], true)
.map_err(|x| Error::Msg(x.to_string()))?,
})
}

fn sample_topp(&mut self, probs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
fn sample_topp(&mut self, probs: &mut Vec<f32>, top_p: f32) -> Result<Logprobs> {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability top_p. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
@@ -74,7 +182,7 @@ impl LogitsProcessor {
self.sample_multinomial(probs)
}

fn sample_topk(&mut self, probs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
fn sample_topk(&mut self, probs: &mut Vec<f32>, top_k: usize) -> Result<Logprobs> {
// Sort probs into descending order (highest probs first)
probs.sort_by(|x, y| x.total_cmp(y));
probs.reverse();
@@ -94,7 +202,7 @@ impl LogitsProcessor {
///
/// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
/// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
pub fn sample(&mut self, logits: &Tensor) -> Result<Logprobs> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = match self.temperature {
None => self.sample_argmax(logits)?,
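For orientation, here is a standalone sketch of the top-n-logprobs idea that this commit wires into sample_argmax and sample_multinomial: rank the candidate probabilities, keep the n most probable tokens, and convert each probability to a log probability. The helper name is illustrative only, and it uses the natural log that logprobs conventionally use, whereas the committed code calls log(10.0).

// Illustrative helper, not part of the commit: top-n logprobs from a probability vector.
fn top_n_logprobs(probs: &[f32], n: usize) -> Vec<(usize, f32)> {
    // Pair each token id with its probability, then sort by probability, descending.
    let mut ranked: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    // Keep the n most probable tokens and take the (natural) log of each probability.
    ranked.into_iter().take(n).map(|(token, p)| (token, p.ln())).collect()
}

// Example: top_n_logprobs(&[0.1, 0.7, 0.2], 2) yields [(1, ln 0.7), (2, ln 0.2)].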

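A minimal usage sketch of the new interface, not part of the commit. It assumes the module is exposed as crate::logits_processor, that a tokenizer.json file is available locally, and that serde_json is added as an extra dependency for printing; the Logprobs fields are private to the module, so the example inspects them through the derived Serialize impl.

use candle_core::{Device, Tensor};
use tokenizers::Tokenizer;

use crate::logits_processor::{LogitsProcessor, SamplingMethod};

fn demo() -> Result<(), Box<dyn std::error::Error>> {
    // Any HuggingFace-format tokenizer file; the path is an assumption.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Seed 42, temperature 0.7, top-k sampling, and report the top 2 logprobs.
    let mut processor =
        LogitsProcessor::new(42, Some(0.7), SamplingMethod::TopK(2), 2, &tokenizer);

    // Toy logits; in practice these come from a model's forward pass.
    let logits = Tensor::new(&[0.1f32, 2.5, -1.0, 0.3], &Device::Cpu)?;

    // `sample` now returns a `Logprobs` struct instead of a bare token id.
    let logprobs = processor.sample(&logits)?;

    // `Logprobs` derives `Serialize`, so its contents can be printed as JSON
    // (the `serde_json` dependency is an assumption of this sketch).
    println!("{}", serde_json::to_string_pretty(&logprobs)?);
    Ok(())
}

Compared to the previous API, the only call-site changes are the two extra constructor arguments (top_n_logprobs and the tokenizer reference) and the richer return type of sample.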