Refactor deserialize_chat_template #354

Merged
merged 2 commits on May 28, 2024
Changes from all commits
3 changes: 3 additions & 0 deletions mistralrs-core/src/pipeline/chat_template.rs
@@ -42,6 +42,9 @@ pub struct ChatTemplate {
     added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
     additional_special_tokens: Option<Vec<String>>,
     pub bos_token: Option<Bos>,
+
+    /// Jinja format chat templating for chat completion.
+    /// See: https://huggingface.co/docs/transformers/chat_templating
     pub chat_template: Option<String>,
     clean_up_tokenization_spaces: Option<bool>,
     device_map: Option<String>,
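For reference, the value of `chat_template` is a Jinja template string embedded in `tokenizer_config.json`. A minimal sketch of how such a field deserializes (the `MinimalChatTemplate` struct and the template text are illustrative stand-ins, not the crate's real types):

```rust
use serde::Deserialize;

// Illustrative stand-in for the `ChatTemplate` struct above; only the
// field relevant to this PR is modeled.
#[derive(Debug, Deserialize)]
struct MinimalChatTemplate {
    chat_template: Option<String>,
}

fn main() {
    // A trimmed-down tokenizer_config.json carrying a Jinja chat template.
    let raw = r#"{"chat_template": "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"}"#;
    let parsed: MinimalChatTemplate = serde_json::from_str(raw).unwrap();
    assert!(parsed.chat_template.is_some());
}
```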
16 changes: 7 additions & 9 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -6,17 +6,15 @@ use super::{
 use crate::aici::bintokens::build_tok_trie;
 use crate::aici::toktree::TokTrie;
 use crate::lora::Ordering;
-use crate::pipeline::chat_template::calculate_eos_tokens;
-use crate::pipeline::Cache;
+use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
+use crate::pipeline::{get_chat_template, Cache};
 use crate::pipeline::{ChatTemplate, LocalModelPaths};
 use crate::prefix_cacher::PrefixCacheManager;
 use crate::sequence::Sequence;
 use crate::utils::tokenizer::get_tokenizer;
 use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
 use crate::xlora_models::NonGranularState;
-use crate::{
-    deserialize_chat_template, do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG,
-};
+use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG};
 use crate::{
     models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
     xlora_models::XLoraQLlama,
@@ -26,8 +24,6 @@ use candle_core::quantized::{ggml_file, GgmlDType};
 use candle_core::{DType, Device, Tensor};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use rand_isaac::Isaac64Rng;
-use serde_json::Value;
-use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::str::FromStr;
@@ -339,8 +335,10 @@ impl Loader for GGMLLoader {
 };

 let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?;
-
-let (chat_template, gen_conf) = deserialize_chat_template!(paths, self);
+let gen_conf: Option<GenerationConfig> = paths
+    .get_gen_conf_filename()
+    .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
+let chat_template = get_chat_template(paths, &self.chat_template);

 let max_seq_len = match model {
     Model::Llama(ref l) => l.max_seq_len,
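The macro call is replaced by two explicit steps: optionally parse `generation_config.json` into a `GenerationConfig`, then resolve the chat template with `get_chat_template`. A standalone sketch of the first step, using a hypothetical `GenConf` stand-in since `GenerationConfig`'s fields are not shown in this diff:

```rust
use serde::Deserialize;
use std::fs;

// Hypothetical stand-in for `GenerationConfig` (defined in
// mistralrs-core/src/pipeline/chat_template.rs); the real fields may differ.
#[derive(Debug, Deserialize)]
struct GenConf {
    eos_token_id: Option<serde_json::Value>,
}

// Mirrors the loader code above: a missing file yields `None`, while a
// present file must parse cleanly (the `unwrap`s panic on malformed JSON).
fn load_gen_conf(path: Option<&str>) -> Option<GenConf> {
    path.map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap())
}

fn main() {
    // With no generation_config.json on disk, the Option simply stays None.
    assert!(load_gen_conf(None).is_none());
}
```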
15 changes: 7 additions & 8 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -6,17 +6,15 @@ use super::{
 use crate::aici::bintokens::build_tok_trie;
 use crate::aici::toktree::TokTrie;
 use crate::lora::Ordering;
-use crate::pipeline::chat_template::calculate_eos_tokens;
-use crate::pipeline::Cache;
+use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
+use crate::pipeline::{get_chat_template, Cache};
 use crate::pipeline::{ChatTemplate, LocalModelPaths};
 use crate::prefix_cacher::PrefixCacheManager;
 use crate::sequence::Sequence;
 use crate::utils::tokenizer::get_tokenizer;
 use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
 use crate::xlora_models::NonGranularState;
-use crate::{
-    deserialize_chat_template, do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG,
-};
+use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG};
 use crate::{
     models::quantized_llama::ModelWeights as QLlama,
     models::quantized_phi2::ModelWeights as QPhi,
@@ -32,8 +30,6 @@ use candle_core::quantized::{
 use candle_core::{DType, Device, Tensor};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use rand_isaac::Isaac64Rng;
-use serde_json::Value;
-use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::str::FromStr;
@@ -486,7 +482,10 @@ impl Loader for GGUFLoader {

 let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?;

-let (chat_template, gen_conf) = deserialize_chat_template!(paths, self);
+let gen_conf: Option<GenerationConfig> = paths
+    .get_gen_conf_filename()
+    .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
+let chat_template = get_chat_template(paths, &self.chat_template);

 let max_seq_len = match model {
     Model::Llama(ref l) => l.max_seq_len,
70 changes: 0 additions & 70 deletions mistralrs-core/src/pipeline/macros.rs
@@ -59,76 +59,6 @@ macro_rules! api_get_file {
     };
 }

-#[macro_export]
-macro_rules! deserialize_chat_template {
-    ($paths:expr, $this:ident) => {{
-        use tracing::info;
-
-        let template: ChatTemplate = serde_json::from_str(&fs::read_to_string(
-            $paths.get_template_filename(),
-        )?).unwrap();
-        let gen_conf: Option<$crate::pipeline::chat_template::GenerationConfig> = $paths.get_gen_conf_filename()
-            .map(|f| serde_json::from_str(&fs::read_to_string(
-                f
-            ).unwrap()).unwrap());
-        #[derive(Debug, serde::Deserialize)]
-        struct SpecifiedTemplate {
-            chat_template: String,
-            bos_token: Option<String>,
-            eos_token: Option<String>,
-        }
-        match template.chat_template {
-            Some(_) => (template, gen_conf),
-            None => {
-                info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
-                let mut deser: HashMap<String, Value> =
-                    serde_json::from_str(&fs::read_to_string($paths.get_template_filename())?)
-                        .unwrap();
-                match $this.chat_template.clone() {
-                    Some(t) => {
-                        if t.ends_with(".json") {
-                            info!("Loading specified loading chat template file at `{t}`.");
-                            let templ: SpecifiedTemplate = serde_json::from_str(&fs::read_to_string(t.clone())?).unwrap();
-                            deser.insert(
-                                "chat_template".to_string(),
-                                Value::String(templ.chat_template),
-                            );
-                            if templ.bos_token.is_some() {
-                                deser.insert(
-                                    "bos_token".to_string(),
-                                    Value::String(templ.bos_token.unwrap()),
-                                );
-                            }
-                            if templ.eos_token.is_some() {
-                                deser.insert(
-                                    "eos_token".to_string(),
-                                    Value::String(templ.eos_token.unwrap()),
-                                );
-                            }
-                            info!("Loaded chat template file.");
-                        } else {
-                            deser.insert(
-                                "chat_template".to_string(),
-                                Value::String(t),
-                            );
-                            info!("Loaded specified literal chat template.");
-                        }
-                    },
-                    None => {
-                        info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
-                        deser.insert(
-                            "chat_template".to_string(),
-                            Value::Null,
-                        );
-                    }
-                };
-                let ser = serde_json::to_string_pretty(&deser).expect("Serialization of modified chat template failed.");
-                (serde_json::from_str(&ser).unwrap(), gen_conf)
-            }
-        }
-    }};
-}
-
 #[macro_export]
 macro_rules! get_paths {
     ($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filename:expr, $silent:expr) => {{
68 changes: 66 additions & 2 deletions mistralrs-core/src/pipeline/mod.rs
@@ -34,6 +34,7 @@
 pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
 use rand_isaac::Isaac64Rng;
 use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
+use serde_json::Value;
 pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
 use std::fmt::{Debug, Display};
 use std::path::Path;
@@ -62,15 +63,14 @@
     fn get_weight_filenames(&self) -> &[PathBuf];

     /// Retrieve the PretrainedConfig file.
     /// See: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
[Docs CI warning, line 66: this URL is not a hyperlink]
     fn get_config_filename(&self) -> &PathBuf;

     /// A serialised `tokenizers.Tokenizer` HuggingFace object.
     /// See: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
[Docs CI warning, line 70: this URL is not a hyperlink]
     fn get_tokenizer_filename(&self) -> &PathBuf;

-    /// Jinja format chat templating for chat completion.
-    /// See: https://huggingface.co/docs/transformers/chat_templating
+    /// Content expected to deserialize to [`ChatTemplate`].
[Docs CI warning, line 73: public documentation for `get_template_filename` links to private item `ChatTemplate`]
     fn get_template_filename(&self) -> &PathBuf;

     /// Optional adapter files. `(String, PathBuf)` is of the form `(id name, path)`.
@@ -1289,6 +1289,70 @@
     }
 }

+/// Find and parse the appropriate [`ChatTemplate`], and ensure it has a valid [`ChatTemplate::chat_template`].
+/// If the provided `tokenizer_config.json` from [`ModelPaths::get_template_filename`] does not
+/// have a `chat_template`, use the provided one.
+#[allow(clippy::borrowed_box)]
+pub(crate) fn get_chat_template(
+    paths: &Box<dyn ModelPaths>,
+    chat_template: &Option<String>,
+) -> ChatTemplate {
+    let template: ChatTemplate =
+        serde_json::from_str(&fs::read_to_string(paths.get_template_filename()).unwrap()).unwrap();
+
+    #[derive(Debug, serde::Deserialize)]
+    struct SpecifiedTemplate {
+        chat_template: String,
+        bos_token: Option<String>,
+        eos_token: Option<String>,
+    }
+
+    if template.chat_template.is_some() {
+        return template;
+    };
+
+    info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
+    let mut deser: HashMap<String, Value> =
+        serde_json::from_str(&fs::read_to_string(paths.get_template_filename()).unwrap()).unwrap();
+
+    match chat_template.clone() {
+        Some(t) => {
+            if t.ends_with(".json") {
+                info!("Loading specified chat template file at `{t}`.");
+                let templ: SpecifiedTemplate =
+                    serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
+                deser.insert(
+                    "chat_template".to_string(),
+                    Value::String(templ.chat_template),
+                );
+                if templ.bos_token.is_some() {
+                    deser.insert(
+                        "bos_token".to_string(),
+                        Value::String(templ.bos_token.unwrap()),
+                    );
+                }
+                if templ.eos_token.is_some() {
+                    deser.insert(
+                        "eos_token".to_string(),
+                        Value::String(templ.eos_token.unwrap()),
+                    );
+                }
+                info!("Loaded chat template file.");
+            } else {
+                deser.insert("chat_template".to_string(), Value::String(t));
+                info!("Loaded specified literal chat template.");
+            }
+        }
+        None => {
+            info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
+            deser.insert("chat_template".to_string(), Value::Null);
+        }
+    };
+    let ser = serde_json::to_string_pretty(&deser)
+        .expect("Serialization of modified chat template failed.");
+    serde_json::from_str(&ser).unwrap()
+}
+
 mod tests {
     #[test]
     /// Generating these cases:
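The tail of `get_chat_template` patches the raw JSON map and round-trips it through a pretty-printed string to obtain a typed `ChatTemplate`. A self-contained sketch of that round-trip, with a hypothetical two-field stand-in for the real struct:

```rust
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;

// Hypothetical stand-in for `ChatTemplate`.
#[derive(Debug, Deserialize)]
struct Templ {
    chat_template: Option<String>,
    bos_token: Option<String>,
}

fn main() {
    // Parse the raw tokenizer_config.json into an untyped map...
    let mut deser: HashMap<String, Value> =
        serde_json::from_str(r#"{"bos_token": "<s>"}"#).unwrap();
    // ...insert a literal template, as with a non-`.json` override...
    deser.insert(
        "chat_template".to_string(),
        Value::String("{{ messages }}".into()),
    );
    // ...then round-trip through a string into the typed struct.
    let ser = serde_json::to_string_pretty(&deser).unwrap();
    let templ: Templ = serde_json::from_str(&ser).unwrap();
    assert_eq!(templ.chat_template.as_deref(), Some("{{ messages }}"));
}
```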
16 changes: 8 additions & 8 deletions mistralrs-core/src/pipeline/normal.rs
@@ -10,25 +10,23 @@ use super::{
 use crate::aici::bintokens::build_tok_trie;
 use crate::aici::toktree::TokTrie;
 use crate::lora::Ordering;
-use crate::pipeline::chat_template::calculate_eos_tokens;
-use crate::pipeline::Cache;
+use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
+use crate::pipeline::{get_chat_template, Cache};
 use crate::pipeline::{ChatTemplate, LocalModelPaths};
 use crate::prefix_cacher::PrefixCacheManager;
 use crate::sequence::Sequence;
 use crate::utils::tokenizer::get_tokenizer;
 use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
 use crate::xlora_models::NonGranularState;
 use crate::{
-    deserialize_chat_template, do_sample, get_mut_arcmutex, get_paths, lora_model_loader,
-    normal_model_loader, xlora_model_loader, DeviceMapMetadata, DEBUG,
+    do_sample, get_mut_arcmutex, get_paths, lora_model_loader, normal_model_loader,
+    xlora_model_loader, DeviceMapMetadata, DEBUG,
 };
 use anyhow::Result;
 use candle_core::quantized::GgmlDType;
 use candle_core::{DType, Device, Tensor};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use rand_isaac::Isaac64Rng;
-use serde_json::Value;
-use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::str::FromStr;
@@ -296,8 +294,10 @@ impl Loader for NormalLoader {
 };

 let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?;
-
-let (chat_template, gen_conf) = deserialize_chat_template!(paths, self);
+let gen_conf: Option<GenerationConfig> = paths
+    .get_gen_conf_filename()
+    .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
+let chat_template = get_chat_template(paths, &self.chat_template);

 if let Some(in_situ_quant) = in_situ_quant {
     model.quantize(in_situ_quant, device.clone())?;