Add support for multiple GGUF files (#692)
* Add multi gguf support

* Allow passing multiple files

* Clippy

* Typo
EricLBuehler authored Aug 18, 2024
1 parent 13f5655 commit 66dba85
Showing 25 changed files with 453 additions and 347 deletions.
9 changes: 6 additions & 3 deletions mistralrs-core/src/gguf/chat_template.rs
@@ -1,9 +1,10 @@
use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use tracing::info;

use crate::utils::gguf_metadata::ContentMetadata;

use super::Content;

struct PropsGGUFTemplate {
chat_template: Option<String>,
}
@@ -23,10 +24,12 @@ impl TryFrom<ContentMetadata<'_>> for PropsGGUFTemplate {
}

// Get chat template from GGUF metadata if it exists
pub fn get_gguf_chat_template(content: &Content) -> Result<Option<String>> {
pub fn get_gguf_chat_template<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Result<Option<String>> {
let metadata = ContentMetadata {
path_prefix: "tokenizer",
metadata: &content.metadata,
metadata: content.get_metadata(),
};
let props = PropsGGUFTemplate::try_from(metadata)?;
if let Some(ref chat_template) = props.chat_template {
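The function is now generic over any `Read + Seek` source. A minimal sketch of driving it from inside the crate, assuming a single, hypothetical GGUF file (`Content` is the new crate-internal wrapper introduced in `content.rs` below):

```rust
use std::fs::File;

// Sketch only: Content and get_gguf_chat_template are crate-internal,
// and the file path is hypothetical.
fn read_template() -> anyhow::Result<Option<String>> {
    let mut file = File::open("model.gguf")?;
    let mut readers = [&mut file];
    let content = Content::from_readers(&mut readers)?;
    get_gguf_chat_template(&content)
}
```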
172 changes: 172 additions & 0 deletions mistralrs-core/src/gguf/content.rs
@@ -0,0 +1,172 @@
use std::{collections::HashMap, fs};

use anyhow::Context;
use candle_core::{
quantized::{
gguf_file::{self, Value},
QTensor,
},
Device, Result,
};
use indexmap::IndexMap;
use tracing::info;

use crate::DEBUG;

use super::GGUFArchitecture;

fn parse_gguf_value(value: &Value) -> String {
match value {
Value::Array(vs) => vs
.iter()
.map(parse_gguf_value)
.collect::<Vec<String>>()
.join(", "),
Value::Bool(b) => b.to_string(),
Value::F32(x) => x.to_string(),
Value::F64(x) => x.to_string(),
Value::I8(x) => x.to_string(),
Value::I16(x) => x.to_string(),
Value::I32(x) => x.to_string(),
Value::I64(x) => x.to_string(),
Value::String(x) => x.to_string(),
Value::U8(x) => x.to_string(),
Value::U16(x) => x.to_string(),
Value::U32(x) => x.to_string(),
Value::U64(x) => x.to_string(),
}
}

// Internal invariant: contents and readers must be paired.
/// This abstracts the files for a GGUF model and enables multiple files to be used.
pub struct Content<'a, R: std::io::Seek + std::io::Read> {
contents: Vec<gguf_file::Content>,
readers: &'a mut [&'a mut R],
arch: GGUFArchitecture,
all_metadata: HashMap<String, Value>,
}

impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> {
/// Create a `Content` from a set of file readers.
pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result<Self> {
let mut contents = Vec::new();
let n_readers = readers.len();
for reader in readers.iter_mut() {
contents.push(gguf_file::Content::read(reader)?);
}
let n_splits = contents
.iter()
.filter_map(|ct| {
ct.metadata
.get("split.count")
.map(|val| val.to_u64().unwrap())
})
.collect::<Vec<_>>();
if n_splits.len() > 1 {
candle_core::bail!("Multiple contents have multiple `split.count` fields");
}
#[allow(clippy::cast_possible_truncation)]
if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
candle_core::bail!("Number of readers does not match the number of splits.");
} else if n_splits.len() == 1 {
info!("Model n splits: {}", n_splits[0]);
}

let mut arch = None;
for ct in &contents {
if !ct.metadata.contains_key("general.architecture") {
continue;
}

arch = Some(
ct.metadata["general.architecture"]
.to_string()
.context("Model metadata should have declared an architecture")
.and_then(GGUFArchitecture::from_value)
.unwrap(),
);
}
let arch = arch.expect("GGUF files must specify `general.architecture`");

let mut all_metadata = HashMap::new();
for content in &contents {
all_metadata.extend(content.metadata.clone())
}

Ok(Self {
contents,
readers,
arch,
all_metadata,
})
}

pub fn arch(&self) -> GGUFArchitecture {
self.arch
}

/// Retrieve a tensor, searching through each content.
pub fn tensor(&mut self, name: &str, device: &Device) -> Result<QTensor> {
for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) {
if let Some(tensor_info) = ct.tensor_infos.get(name) {
return tensor_info.read(reader, ct.tensor_data_offset, device);
}
}
candle_core::bail!("Cannot find tensor info for {name}")
}

/// Print metadata for these contents.
/// This will also log tensor name, shape and dtype to `mistralrs_gguf_tensors.txt` if DEBUG is enabled.
pub fn print_metadata(&self) -> anyhow::Result<()> {
// Gather metadata keys and values (and tensor info when DEBUG is set) across all contents
let mut keys = Vec::new();
let mut metadatas = Vec::new();
let mut tensors = Vec::new();
for ct in &self.contents {
keys.extend(ct.metadata.keys());
metadatas.push(&ct.metadata);

if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
for (name, info) in &ct.tensor_infos {
tensors.push(format!(
"name = `{name}`, shape = {:?}, dtype = {:?}",
info.shape.clone(),
info.ggml_dtype
));
}
}
}

info!("Model config:");
keys.sort();
let mut output_keys = IndexMap::new();
for name in keys {
if !name.contains("tokenizer") {
for metadata in &metadatas {
if let Some(val) = metadata.get(name) {
output_keys.insert(name, parse_gguf_value(val));
}
}
}
}
for (name, val) in output_keys {
println!("{name}: {val}")
}

if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
fs::write(
"mistralrs_gguf_tensors.txt",
serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
)?;

info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
}

anyhow::Ok(())
}

/// Get the merged metadata from all files
pub fn get_metadata(&self) -> &HashMap<String, Value> {
&self.all_metadata
}
}
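To make the new abstraction concrete, here is a sketch of opening a model split across two shards and reading one tensor from whichever shard holds it. The file names and tensor name are hypothetical, and `Content` is crate-internal, so code like this would live inside `mistralrs-core`:

```rust
use std::fs::File;

use candle_core::Device;

fn load_sharded() -> anyhow::Result<()> {
    let mut shard0 = File::open("model-00001-of-00002.gguf")?;
    let mut shard1 = File::open("model-00002-of-00002.gguf")?;
    // `split.count` in the metadata, if present, must equal the number of
    // readers, or from_readers bails.
    let mut readers = [&mut shard0, &mut shard1];
    let mut content = Content::from_readers(&mut readers)?;

    println!("architecture: {:?}", content.arch());

    // tensor() walks the shards in order and reads from the first one whose
    // tensor_infos contains the name.
    let embd = content.tensor("token_embd.weight", &Device::Cpu)?;
    println!("shape: {:?}", embd.shape());
    Ok(())
}
```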
9 changes: 6 additions & 3 deletions mistralrs-core/src/gguf/gguf_tokenizer.rs
@@ -3,7 +3,6 @@
use std::{collections::HashMap, sync::atomic::Ordering};

use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use itertools::Itertools;
use tokenizers::{
decoders::{
@@ -23,6 +22,8 @@ use tracing::info;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;

use super::Content;

pub(crate) struct GgufTokenizerConversion {
pub tokenizer: Tokenizer,
pub bos: Option<String>,
@@ -71,10 +72,12 @@ struct AddedTokensCollection {
unk: Option<String>,
}

pub fn convert_gguf_to_hf_tokenizer(content: &Content) -> Result<GgufTokenizerConversion> {
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
let metadata = ContentMetadata {
path_prefix: "tokenizer.ggml",
metadata: &content.metadata,
metadata: content.get_metadata(),
};
let props = PropsGGUF::try_from(metadata)?;

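The tokenizer conversion gains the same generic signature as the chat-template helper; a minimal sketch under the same assumptions (crate-internal API, hypothetical path):

```rust
fn load_tokenizer() -> anyhow::Result<tokenizers::Tokenizer> {
    let mut file = std::fs::File::open("model.gguf")?; // hypothetical path
    let mut readers = [&mut file];
    let content = Content::from_readers(&mut readers)?;
    // The converted tokenizer and any special tokens come back together
    // in a GgufTokenizerConversion.
    let conversion = convert_gguf_to_hf_tokenizer(&content)?;
    Ok(conversion.tokenizer)
}
```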
37 changes: 36 additions & 1 deletion mistralrs-core/src/gguf/mod.rs
@@ -1,5 +1,40 @@
mod chat_template;
mod content;
mod gguf_tokenizer;
use strum::EnumString;

pub use chat_template::get_gguf_chat_template;
use anyhow::{Context, Result};
pub(crate) use chat_template::get_gguf_chat_template;
pub(crate) use content::Content;
pub(crate) use gguf_tokenizer::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion};
use std::str::FromStr;

pub const GGUF_MULTI_FILE_DELIMITER: &str = " ";

#[derive(Debug, EnumString, Clone, Copy)]
#[strum(serialize_all = "kebab-case")]
pub enum GGUFArchitecture {
Llama,
Mpt,
Gptneox,
Gptj,
Gpt2,
Bloom,
Falcon,
Mamba,
Rwkv,
Phi2,
Phi3,
Starcoder2,
}

// Wraps from_str() for some convenience:
// - Case-insensitive variant matching (TODO: is this desirable?)
// - Customized error until potential upstream support: https://github.com/Peternator7/strum/issues/332
impl GGUFArchitecture {
pub fn from_value<T: AsRef<str> + std::fmt::Display>(value: T) -> Result<Self> {
Self::from_str(&value.as_ref().to_ascii_lowercase())
.with_context(|| format!("Unknown GGUF architecture `{value}`"))
.map_err(anyhow::Error::msg)
}
}
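A short sketch of what `from_value` adds over the raw strum derive, namely case-insensitive matching and a readable error (`GGUFArchitecture` is re-exported from `mistralrs-core`, see `lib.rs` below):

```rust
use mistralrs_core::GGUFArchitecture;

fn detect_arch() -> anyhow::Result<()> {
    // Input is lowercased before from_str, so any casing of a known
    // variant works.
    let arch = GGUFArchitecture::from_value("LLaMA")?;
    println!("{arch:?}"); // Llama
    // Unknown names produce the customized error instead of strum's default.
    assert!(GGUFArchitecture::from_value("not-an-arch").is_err());
    Ok(())
}
```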
12 changes: 6 additions & 6 deletions mistralrs-core/src/layers.rs
@@ -12,7 +12,7 @@ use std::{
};

use candle_core::{
quantized::{gguf_file, QMatMul, QTensor},
quantized::{QMatMul, QTensor},
DType, Device, IndexOp, Result, Shape, Tensor, D,
};
use candle_nn::{Linear, Module, VarBuilder};
@@ -22,7 +22,8 @@ use serde::Deserialize;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv};
use crate::{
cublaslt::CUBLASLT_HANDLE, models::llama, pipeline::Phi3RopeScaling, INHIBIT_GEMM_F16,
cublaslt::CUBLASLT_HANDLE, gguf::Content, models::llama, pipeline::Phi3RopeScaling,
INHIBIT_GEMM_F16,
};

#[derive(Debug, Clone)]
@@ -612,13 +613,12 @@ pub struct QLinear {

impl QLinear {
pub fn new<R: std::io::Read + std::io::Seek>(
ct: &gguf_file::Content,
r: &mut R,
ct: &mut Content<'_, R>,
name: &str,
device: &Device,
) -> Result<Self> {
let w = ct.tensor(r, &format!("{name}.weight"), device)?;
let b = ct.tensor(r, &format!("{name}.bias"), device)?;
let w = ct.tensor(&format!("{name}.weight"), device)?;
let b = ct.tensor(&format!("{name}.bias"), device)?;
let inner = QMatMul::from_qtensor(w)?;
let bias = b.dequantize(device)?;
Ok(Self {
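Since the reader/content pairing now lives inside `Content`, call sites no longer thread a separate reader through. A sketch of the updated call shape (the layer name is hypothetical):

```rust
// Content resolves "{name}.weight" and "{name}.bias" across shards itself,
// so the old (ct, r) pair collapses into one argument.
fn build_layer<R: std::io::Read + std::io::Seek>(
    content: &mut Content<'_, R>,
) -> candle_core::Result<QLinear> {
    QLinear::new(content, "blk.0.attn_q", &candle_core::Device::Cpu)
}
```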
14 changes: 7 additions & 7 deletions mistralrs-core/src/lib.rs
@@ -60,17 +60,17 @@ mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{DeviceLayerMapMetadata, DeviceMapMetadata, LayerDeviceMapper};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_quant::IsqType;
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
pub use pipeline::{
chat_template::ChatTemplate, parse_isq_value, AnyMoeLoader, AnyMoePipeline, GGMLLoader,
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder,
GemmaLoader, Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader,
NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
VisionModelLoader, VisionSpecificConfig,
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GemmaLoader,
Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, Starcoder2Loader, TokenSource,
VisionLoader, VisionLoaderBuilder, VisionLoaderType, VisionModelLoader, VisionSpecificConfig,
};
pub use request::{Constraint, MessageContent, NormalRequest, Request, RequestMessage};
pub use response::Response;
17 changes: 13 additions & 4 deletions mistralrs-core/src/model_loader.rs
@@ -7,7 +7,7 @@ use crate::{
get_toml_selected_model_dtype,
pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector,
VisionLoaderBuilder, VisionSpecificConfig,
VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
};

/// A builder for a loader using the selected model.
@@ -189,7 +189,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
args.chat_template,
tok_model_id,
quantized_model_id,
quantized_filename,
quantized_filename
.split(GGUF_MULTI_FILE_DELIMITER)
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
args.prompt_batchsize,
)
.build(),
@@ -204,7 +207,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
args.chat_template,
tok_model_id,
quantized_model_id,
quantized_filename,
quantized_filename
.split(GGUF_MULTI_FILE_DELIMITER)
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
args.prompt_batchsize,
)
.with_xlora(
@@ -227,7 +233,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
args.chat_template,
tok_model_id,
quantized_model_id,
quantized_filename,
quantized_filename
.split(GGUF_MULTI_FILE_DELIMITER)
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
args.prompt_batchsize,
)
.with_lora(
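Each loader builder now receives a `Vec<String>` of shard names, produced by splitting the single CLI argument on `GGUF_MULTI_FILE_DELIMITER`. A sketch of that convention with hypothetical file names:

```rust
use mistralrs_core::GGUF_MULTI_FILE_DELIMITER;

fn main() {
    // The delimiter is a single space, so individual shard names must not
    // contain spaces themselves.
    let quantized_filename = "model-00001-of-00002.gguf model-00002-of-00002.gguf";
    let files: Vec<String> = quantized_filename
        .split(GGUF_MULTI_FILE_DELIMITER)
        .map(ToOwned::to_owned)
        .collect();
    assert_eq!(files.len(), 2);
}
```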
(Diffs for the remaining changed files are not shown.)