Skip to content

Commit

Permalink
Fix start of context
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jul 27, 2024
1 parent fe490a6 commit 88f99bb
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 70 deletions.
36 changes: 0 additions & 36 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand All @@ -19,7 +18,6 @@ use crate::{
scheduler::{Scheduler, SchedulerOutput},
CompletionResponse, RequestMessage, Response, SchedulerConfig, DEBUG,
};
use candle_core::{Device, Result, Tensor};
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use tracing::{info, warn};
Expand Down Expand Up @@ -430,26 +428,6 @@ impl Engine {
Ok(recognizer)
}

fn alloc_logits_bias(&self, logits_bias: Option<HashMap<u32, f32>>) -> Result<Option<Tensor>> {
let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();
let vocab_size = tokenizer.get_vocab_size(true);

match logits_bias {
Some(bias) => {
let mut logits_bias = vec![0.0; vocab_size];
for (k, v) in bias {
logits_bias[k as usize] = v;
}
Ok(Some(Tensor::from_vec(
logits_bias,
vocab_size,
&Device::Cpu,
)?))
}
None => Ok(None),
}
}

async fn handle_request(&mut self, request: Request) {
match request {
Request::ActivateAdapters(adapters) => {
Expand Down Expand Up @@ -644,19 +622,6 @@ impl Engine {
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!");

let logits_bias = match self.alloc_logits_bias(request.sampling_params.logits_bias) {
Ok(logits_bias) => logits_bias,
Err(err) => {
request
.response
.send(Response::ValidationError(
format!("Failed creation of logits bias. {}", err).into(),
))
.await
.expect("Expected receiver.");
return;
}
};
let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();

let sampler = Sampler::new(
Expand All @@ -665,7 +630,6 @@ impl Engine {
tokenizer,
request.sampling_params.frequency_penalty,
request.sampling_params.presence_penalty,
logits_bias,
topk,
topp,
minp,
Expand Down
3 changes: 1 addition & 2 deletions mistralrs-core/src/pipeline/amoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {

// Create several dummy objects for the sequences.
let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
let dummy_sampler =
Sampler::new(None, 0, tokenizer.clone(), None, None, None, -1, 0.0, 0.0);
let dummy_sampler = Sampler::new(None, 0, tokenizer.clone(), None, None, -1, 0.0, 0.0);

let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
1, false, false, 0,
Expand Down
5 changes: 2 additions & 3 deletions mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,9 @@ pub async fn sample_sequence(
sample_speculative: bool,
) -> Result<Logprobs> {
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let start_at = seq.get_toks().len().saturating_sub(seq.prompt_tokens());

let sampler = seq.sampler();
let ctx_clone = seq.get_toks()[start_at..].to_vec();
let ctx_clone = seq.get_toks()[seq.prompt_tokens()..].to_vec();
let rng_clone = rng.clone();
let logits_clone = logits.clone();
let first_lobprobs_response = if use_async_pool {
Expand Down Expand Up @@ -316,7 +315,7 @@ pub async fn sample_sequence(
token_set.apply_to(&mut acc);
let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?;

let ctx_clone = seq.get_toks()[start_at..].to_vec();
let ctx_clone = seq.get_toks()[seq.prompt_tokens()..].to_vec();
let rng_clone = rng.clone();
let sampler = seq.sampler();
if use_async_pool {
Expand Down
31 changes: 2 additions & 29 deletions mistralrs-core/src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ pub struct Sampler {
tokenizer: Arc<Tokenizer>,
frequency_penalty: Option<f32>,
presence_penalty: Option<f32>,
logits_bias: Option<Tensor>,
top_k: i64,
top_p: f64,
min_p: f64,
Expand Down Expand Up @@ -100,7 +99,6 @@ impl Sampler {
tokenizer: Arc<Tokenizer>,
frequency_penalty: Option<f32>,
presence_penalty: Option<f32>,
logits_bias: Option<Tensor>,
top_k: i64,
top_p: f64,
min_p: f64,
Expand All @@ -116,7 +114,6 @@ impl Sampler {
tokenizer,
frequency_penalty,
presence_penalty,
logits_bias,
top_k,
top_p,
min_p,
Expand Down Expand Up @@ -400,10 +397,6 @@ impl Sampler {
sample_speculative: bool,
) -> Result<Logprobs> {
let logits = self.apply_penalties(logits.to_vec1()?, penalty_ctxt)?;
let logits = match self.logits_bias {
Some(ref bias) => (logits + bias)?,
None => logits,
};
let next_token = if sample_speculative {
match self.temperature {
None => self.sample_speculative_top_kp_min_p(
Expand Down Expand Up @@ -475,17 +468,7 @@ mod tests {
use std::sync::Arc;
use std::sync::Mutex;

let sampler = Sampler::new(
None,
10,
get_tokenizer().into(),
None,
None,
None,
32,
0.1,
0.05,
);
let sampler = Sampler::new(None, 10, get_tokenizer().into(), None, None, 32, 0.1, 0.05);
let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
let res = sampler.sample(logits, None, false, rng, false).unwrap();
Expand All @@ -503,17 +486,7 @@ mod tests {
use std::sync::Arc;
use std::sync::Mutex;

let sampler = Sampler::new(
None,
10,
get_tokenizer().into(),
None,
None,
None,
32,
0.1,
0.05,
);
let sampler = Sampler::new(None, 10, get_tokenizer().into(), None, None, 32, 0.1, 0.05);
let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
let res = sampler.sample(logits, None, false, rng, true).unwrap();
Expand Down

0 comments on commit 88f99bb

Please sign in to comment.