diff --git a/src/logits_processor.rs b/src/logits_processor.rs index 695538e..fac754d 100644 --- a/src/logits_processor.rs +++ b/src/logits_processor.rs @@ -1,6 +1,6 @@ use std::{cmp::Ordering, collections::HashMap, iter::zip}; -use candle_core::{bail, DType, Error, Result, Tensor}; +use candle_core::{bail, DType, Device, Error, Result, Tensor, D}; use rand::{ distributions::{Distribution, WeightedIndex}, SeedableRng, @@ -34,17 +34,17 @@ pub enum SamplingMethod { TopKP((usize, f64)), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] // Top-n logprobs element pub struct TopLogprob { - pub token: u32, + pub token: Tensor, pub logprob: f32, pub bytes: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct Logprobs { - pub token: u32, + pub token: Tensor, pub logprob: f32, pub bytes: String, pub top_logprobs: Vec, @@ -79,7 +79,7 @@ impl LogitsProcessor { } } - fn apply_logit_bias(&self, probs: &mut [f32]) -> Result<()> { + /*fn apply_logit_bias(&self, probs: &mut [f32]) -> Result<()> { if let Some(ref bias) = self.logits_bias { for (id, bias_v) in bias { let idx = probs.get_mut(*id as usize); @@ -383,5 +383,15 @@ impl LogitsProcessor { } }; Ok(next_token) + }*/ + + pub fn sample(&mut self, logits: &Tensor, penalty_ctxt: Option<&Tensor>) -> Result { + let logits = logits.to_dtype(DType::F32)?; + Ok(Logprobs { + token: logits.argmax(D::Minus1)?, + logprob: 0.0, + bytes: "".to_string(), + top_logprobs: vec![], + }) } }