Improve penalty context window calculation (#636)
* Remove necessity of repeat_last_n

* Fix start of context
EricLBuehler committed Jul 27, 2024
1 parent 874fe4b commit 5e8fe09
Showing 40 changed files with 65 additions and 362 deletions.
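
For context, a minimal sketch (mirroring the updated docs snippets in this commit, with a placeholder model id) of how a plain-model loader is now constructed. NormalSpecificConfig no longer carries a repeat_last_n field, so the penalty context window is no longer a per-loader setting:

    // Hedged sketch based on the docs diffs below; the model id is a placeholder
    // and the two `None` arguments are the chat-template and tokenizer.json overrides.
    let builder = NormalLoaderBuilder::new(
        NormalSpecificConfig {
            use_flash_attn: false,
            // `repeat_last_n: 64` is no longer a field of this struct.
        },
        None,
        None,
        Some("model-id".to_string()),
    );
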
1 change: 0 additions & 1 deletion docs/ANYMOE.md
@@ -171,7 +171,6 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
let loader = NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn: false,
repeat_last_n: 64,
},
None,
None,
1 change: 0 additions & 1 deletion docs/IDEFICS2.md
@@ -113,7 +113,6 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
let loader = VisionLoaderBuilder::new(
VisionSpecificConfig {
use_flash_attn: false,
repeat_last_n: 64,
},
None,
None,
1 change: 0 additions & 1 deletion docs/LLaVA.md
@@ -110,7 +110,6 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
let loader = VisionLoaderBuilder::new(
VisionSpecificConfig {
use_flash_attn: false,
repeat_last_n: 64,
},
None,
None,
1 change: 0 additions & 1 deletion docs/PAGED_ATTENTION.md
@@ -68,7 +68,6 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
let loader = NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn: false,
repeat_last_n: 64,
},
None,
None,
1 change: 0 additions & 1 deletion docs/PHI3V.md
@@ -121,7 +121,6 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
let loader = VisionLoaderBuilder::new(
VisionSpecificConfig {
use_flash_attn: false,
repeat_last_n: 64,
},
None,
None,
38 changes: 2 additions & 36 deletions mistralrs-core/src/engine/mod.rs
@@ -1,5 +1,4 @@
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
@@ -19,7 +18,6 @@ use crate::{
scheduler::{Scheduler, SchedulerOutput},
CompletionResponse, RequestMessage, Response, SchedulerConfig, DEBUG,
};
use candle_core::{Device, Result, Tensor};
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use tracing::{info, warn};
@@ -430,26 +428,6 @@ impl Engine {
Ok(recognizer)
}

fn alloc_logits_bias(&self, logits_bias: Option<HashMap<u32, f32>>) -> Result<Option<Tensor>> {
let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();
let vocab_size = tokenizer.get_vocab_size(true);

match logits_bias {
Some(bias) => {
let mut logits_bias = vec![0.0; vocab_size];
for (k, v) in bias {
logits_bias[k as usize] = v;
}
Ok(Some(Tensor::from_vec(
logits_bias,
vocab_size,
&Device::Cpu,
)?))
}
None => Ok(None),
}
}

async fn handle_request(&mut self, request: Request) {
match request {
Request::ActivateAdapters(adapters) => {
@@ -644,19 +622,6 @@ impl Engine {
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!");

let logits_bias = match self.alloc_logits_bias(request.sampling_params.logits_bias) {
Ok(logits_bias) => logits_bias,
Err(err) => {
request
.response
.send(Response::ValidationError(
format!("Failed creation of logits bias. {}", err).into(),
))
.await
.expect("Expected receiver.");
return;
}
};
let tokenizer = get_mut_arcmutex!(self.pipeline).tokenizer();

let sampler = Sampler::new(
@@ -665,7 +630,6 @@
tokenizer,
request.sampling_params.frequency_penalty,
request.sampling_params.presence_penalty,
logits_bias,
topk,
topp,
minp,
@@ -703,6 +667,7 @@
.cache_config
.clone()
.map(|conf| conf.block_size);
let trie = (*get_mut_arcmutex!(self.pipeline).get_metadata().tok_trie).clone();
let seq = Sequence::new_waiting(
prompt.clone(),
self.id,
@@ -733,6 +698,7 @@
request.adapters.clone(),
images.clone(),
block_size,
trie,
);
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
seq.prefill(
13 changes: 6 additions & 7 deletions mistralrs-core/src/lib.rs
@@ -62,13 +62,12 @@ pub use device_map::{DeviceLayerMapMetadata, DeviceMapMetadata, LayerDeviceMappe
pub use paged_attention::PagedAttentionConfig;
pub use pipeline::{
chat_template::ChatTemplate, AnyMoeLoader, AnyMoePipeline, GGMLLoader, GGMLLoaderBuilder,
GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
GemmaLoader, Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader,
NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
VisionModelLoader, VisionSpecificConfig,
GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GemmaLoader,
Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, Starcoder2Loader, TokenSource,
VisionLoader, VisionLoaderBuilder, VisionLoaderType, VisionModelLoader, VisionSpecificConfig,
};
pub use request::{Constraint, MessageContent, NormalRequest, Request, RequestMessage};
pub use response::Response;
44 changes: 8 additions & 36 deletions mistralrs-core/src/model_loader.rs
@@ -2,10 +2,7 @@ use std::fs::{self, File};

use crate::{
get_toml_selected_model_dtype,
pipeline::{
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig,
NormalSpecificConfig,
},
pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector,
VisionLoaderBuilder, VisionSpecificConfig,
};
@@ -110,15 +107,11 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
}
ModelSelected::Plain {
model_id,
repeat_last_n,
tokenizer_json,
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn,
repeat_last_n,
},
NormalSpecificConfig { use_flash_attn },
args.chat_template,
tokenizer_json,
Some(model_id),
@@ … @@
ModelSelected::XLora {
model_id,
xlora_model_id,
repeat_last_n,
order,
tokenizer_json,
tgt_non_granular_index,
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn,
repeat_last_n,
},
NormalSpecificConfig { use_flash_attn },
args.chat_template,
tokenizer_json,
model_id,
@@ … @@
model_id,
tokenizer_json,
adapters_model_id,
repeat_last_n,
order,
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn,
repeat_last_n,
},
NormalSpecificConfig { use_flash_attn },
args.chat_template,
tokenizer_json,
model_id,
@@ -181,9 +166,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tok_model_id,
quantized_model_id,
quantized_filename,
repeat_last_n,
} => GGUFLoaderBuilder::new(
GGUFSpecificConfig { repeat_last_n },
args.chat_template,
tok_model_id,
quantized_model_id,
@@ … @@
tok_model_id,
quantized_model_id,
quantized_filename,
repeat_last_n,
xlora_model_id,
order,
tgt_non_granular_index,
} => GGUFLoaderBuilder::new(
GGUFSpecificConfig { repeat_last_n },
args.chat_template,
tok_model_id,
quantized_model_id,
@@ -219,11 +200,9 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tok_model_id,
quantized_model_id,
quantized_filename,
repeat_last_n,
adapters_model_id,
order,
} => GGUFLoaderBuilder::new(
GGUFSpecificConfig { repeat_last_n },
args.chat_template,
tok_model_id,
quantized_model_id,
@@ … @@
tokenizer_json,
quantized_model_id,
quantized_filename,
repeat_last_n,
gqa,
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig { repeat_last_n, gqa },
GGMLSpecificConfig { gqa },
args.chat_template,
tokenizer_json,
Some(tok_model_id),
@@ … @@
tokenizer_json,
quantized_model_id,
quantized_filename,
repeat_last_n,
xlora_model_id,
order,
tgt_non_granular_index,
gqa,
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig { repeat_last_n, gqa },
GGMLSpecificConfig { gqa },
args.chat_template,
tokenizer_json,
tok_model_id,
@@ … @@
tokenizer_json,
quantized_model_id,
quantized_filename,
repeat_last_n,
adapters_model_id,
order,
gqa,
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig { repeat_last_n, gqa },
GGMLSpecificConfig { gqa },
args.chat_template,
tokenizer_json,
tok_model_id,
@@ … @@
.build(),
ModelSelected::VisionPlain {
model_id,
repeat_last_n,
tokenizer_json,
arch,
dtype: _,
} => VisionLoaderBuilder::new(
VisionSpecificConfig {
use_flash_attn,
repeat_last_n,
},
VisionSpecificConfig { use_flash_attn },
args.chat_template,
tokenizer_json,
Some(model_id),
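
The quantized loaders follow the same pattern as the plain and vision ones above: GGUFLoaderBuilder::new drops its GGUFSpecificConfig argument entirely, and GGMLSpecificConfig keeps only gqa. A hedged before/after sketch, with the remaining builder arguments elided exactly as in the diff:

    // Before this commit:
    //   GGUFLoaderBuilder::new(GGUFSpecificConfig { repeat_last_n }, chat_template, tok_model_id, ...)
    //   GGMLSpecificConfig { repeat_last_n, gqa }
    // After this commit:
    //   GGUFLoaderBuilder::new(chat_template, tok_model_id, quantized_model_id, ...)
    let ggml_config = GGMLSpecificConfig { gqa: 1 }; // gqa value is a placeholder
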