
Merge branch 'master' of github.com:scottwey/mistral.rs
scottwey committed Sep 15, 2024
2 parents c6f1845 + 8928c83 commit c8c727a
Showing 6 changed files with 44 additions and 2 deletions.
7 changes: 5 additions & 2 deletions mistralrs-core/src/engine/mod.rs
@@ -76,10 +76,13 @@ impl Engine {
     ) -> Self {
         let device = get_mut_arcmutex!(pipeline).device().clone();
         let is_xlora = get_mut_arcmutex!(pipeline).get_metadata().is_xlora;
+        let has_no_kv_cache = get_mut_arcmutex!(pipeline).get_metadata().has_no_kv_cache;
+        assert_eq!(has_no_kv_cache, no_kv_cache);
         // Prefix caching is always disabled if using PagedAttention for now.
         // TODO
-        let no_prefix_cache =
-            matches!(config, SchedulerConfig::PagedAttentionMeta { .. }) || no_prefix_cache;
+        let no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
+            || no_prefix_cache
+            || has_no_kv_cache;
         Self {
             rx,
             pipeline,
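For readers skimming the diff, the new condition reads as a single predicate: prefix caching is now disabled when PagedAttention is in use, when the caller requested it explicitly, or when the pipeline runs with no KV cache at all (a prefix cache has nothing to reuse when no KV cache is kept). A minimal standalone sketch of that logic, with illustrative parameter names rather than the actual Engine fields:

fn prefix_cache_disabled(
    uses_paged_attention: bool, // matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
    no_prefix_cache_flag: bool, // caller-requested flag
    has_no_kv_cache: bool,      // pipeline metadata
) -> bool {
    // Prefix caching stays off with PagedAttention for now (see the TODO above),
    // and is meaningless when the model runs without a KV cache.
    uses_paged_attention || no_prefix_cache_flag || has_no_kv_cache
}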
8 changes: 8 additions & 0 deletions mistralrs-core/src/model_loader.rs
@@ -133,6 +133,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tokenizer_json,
Some(model_id),
)
.with_no_kv_cache(args.no_kv_cache)
.build(arch)?,
ModelSelected::XLora {
model_id,
@@ -154,6 +155,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tokenizer_json,
model_id,
)
.with_no_kv_cache(args.no_kv_cache)
.with_xlora(
xlora_model_id,
serde_json::from_reader(
@@ -183,6 +185,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tokenizer_json,
model_id,
)
.with_no_kv_cache(args.no_kv_cache)
.with_lora(
adapters_model_id,
serde_json::from_reader(
@@ -231,6 +234,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
topology: Topology::from_option_path(topology)?,
},
)
.with_no_kv_cache(args.no_kv_cache)
.with_xlora(
xlora_model_id,
serde_json::from_reader(
@@ -261,6 +265,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
topology: Topology::from_option_path(topology)?,
},
)
.with_no_kv_cache(args.no_kv_cache)
.with_lora(
adapters_model_id,
serde_json::from_reader(
@@ -288,6 +293,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
quantized_model_id,
quantized_filename,
)
.with_no_kv_cache(args.no_kv_cache)
.build(),
ModelSelected::XLoraGGML {
tok_model_id,
@@ -311,6 +317,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
quantized_model_id,
quantized_filename,
)
.with_no_kv_cache(args.no_kv_cache)
.with_xlora(
xlora_model_id,
serde_json::from_reader(
@@ -342,6 +349,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
quantized_model_id,
quantized_filename,
)
.with_no_kv_cache(args.no_kv_cache)
.with_lora(
adapters_model_id,
serde_json::from_reader(
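Every loader constructed in `model_loader.rs` now threads `args.no_kv_cache` through the new method before `build()`. A toy, self-contained sketch of that consuming-builder call pattern (the type and fields below are illustrative stand-ins, not the real loader builders, which take many more arguments):

// Toy stand-in that mirrors the `.with_no_kv_cache(...)` call pattern above.
struct ToyLoaderBuilder {
    no_kv_cache: bool,
}

impl ToyLoaderBuilder {
    fn new() -> Self {
        Self { no_kv_cache: false }
    }

    // Same shape as the method added in this commit: take `self` by value,
    // set the flag, and return `Self` so calls can be chained.
    fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    fn build(self) -> bool {
        self.no_kv_cache
    }
}

fn main() {
    // Equivalent in spirit to `...with_no_kv_cache(args.no_kv_cache).build()`.
    let no_kv_cache = ToyLoaderBuilder::new().with_no_kv_cache(true).build();
    assert!(no_kv_cache);
}

Returning `Self` keeps the new setter chaining the same way the existing `.with_xlora(...)` and `.with_lora(...)` calls do in this file.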
7 changes: 7 additions & 0 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -100,6 +100,7 @@ pub struct GGMLLoaderBuilder {
}

impl GGMLLoaderBuilder {
/// NOTE: Until v0.4.0, you should make sure to call `.with_no_kv_cache` if applicable.
pub fn new(
config: GGMLSpecificConfig,
chat_template: Option<String>,
@@ -124,6 +125,12 @@ impl GGMLLoaderBuilder {
}
}

// TODO(EricLBuehler): in 0.4.0 we can move this into the config
pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
self.no_kv_cache = no_kv_cache;
self
}

fn with_adapter(
mut self,
xlora_model_id: String,
8 changes: 8 additions & 0 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -115,6 +115,8 @@ impl GGUFLoaderBuilder {
/// Create a loader builder for a GGUF model. `tok_model_id` is the model ID where you can find a
/// `tokenizer_config.json` file. If the `chat_template` is specified, then it will be treated as a
/// path and used over remote files, removing all remote accesses.
///
/// NOTE: Until v0.4.0, you should make sure to call `.with_no_kv_cache` if applicable.
pub fn new(
chat_template: Option<String>,
tok_model_id: Option<String>,
@@ -137,6 +139,12 @@ impl GGUFLoaderBuilder {
}
}

// TODO(EricLBuehler): in 0.4.0 we can move this into the config
pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
self.no_kv_cache = no_kv_cache;
self
}

fn with_adapter(
mut self,
xlora_model_id: String,
7 changes: 7 additions & 0 deletions mistralrs-core/src/pipeline/normal.rs
@@ -99,6 +99,7 @@ pub struct NormalSpecificConfig {
}

impl NormalLoaderBuilder {
/// NOTE: Until v0.4.0, you should make sure to call `.with_no_kv_cache` if applicable.
pub fn new(
config: NormalSpecificConfig,
chat_template: Option<String>,
@@ -115,6 +116,12 @@ impl NormalLoaderBuilder {
}
}

// TODO(EricLBuehler): in 0.4.0 we can move this into the config
pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
self.no_kv_cache = no_kv_cache;
self
}

fn with_adapter(
mut self,
xlora_model_id: String,
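The `TODO(EricLBuehler)` comments repeated in each builder all point at the same cleanup: once v0.4.0 allows a breaking change, the flag could live in the per-model config rather than a separate builder method. A speculative sketch of that direction, using `NormalSpecificConfig` as the example (the existing fields are elided and the new field is hypothetical, not part of this commit):

// Hypothetical v0.4.0 shape: the flag becomes part of the config that is
// already passed to the loader builder, removing the extra setter.
pub struct NormalSpecificConfig {
    // ...existing fields elided...
    pub no_kv_cache: bool,
}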
9 changes: 9 additions & 0 deletions mistralrs-pyo3/src/lib.rs
@@ -99,6 +99,7 @@ fn parse_which(
tokenizer_json,
Some(model_id),
)
.with_no_kv_cache(no_kv_cache)
.build(arch.map(Into::into))?,
Which::XLora {
model_id,
@@ -119,6 +120,7 @@ fn parse_which(
tokenizer_json,
model_id,
)
.with_no_kv_cache(no_kv_cache)
.with_xlora(
xlora_model_id,
serde_json::from_reader(
@@ -147,6 +149,7 @@ fn parse_which(
tokenizer_json,
model_id,
)
.with_no_kv_cache(no_kv_cache)
.with_lora(
adapters_model_id,
serde_json::from_reader(
@@ -170,6 +173,7 @@ fn parse_which(
topology: Topology::from_option_path(topology)?,
},
)
.with_no_kv_cache(no_kv_cache)
.build(),
Which::XLoraGGUF {
tok_model_id,
@@ -189,6 +193,7 @@ fn parse_which(
topology: Topology::from_option_path(topology)?,
},
)
.with_no_kv_cache(no_kv_cache)
.with_xlora(
xlora_model_id,
serde_json::from_reader(
@@ -216,6 +221,7 @@ fn parse_which(
topology: Topology::from_option_path(topology)?,
},
)
.with_no_kv_cache(no_kv_cache)
.with_lora(
adapters_model_id,
serde_json::from_reader(
@@ -243,6 +249,7 @@ fn parse_which(
quantized_model_id,
quantized_filename,
)
.with_no_kv_cache(no_kv_cache)
.build(),
Which::XLoraGGML {
tok_model_id,
@@ -266,6 +273,7 @@ fn parse_which(
quantized_model_id,
quantized_filename,
)
.with_no_kv_cache(no_kv_cache)
.with_xlora(
xlora_model_id,
serde_json::from_reader(
@@ -297,6 +305,7 @@ fn parse_which(
quantized_model_id,
quantized_filename,
)
.with_no_kv_cache(no_kv_cache)
.with_lora(
adapters_model_id,
serde_json::from_reader(
