Add support for multiple GGUF files #692

Merged · 4 commits · Aug 18, 2024

Changes from all commits
9 changes: 6 additions & 3 deletions mistralrs-core/src/gguf/chat_template.rs
@@ -1,9 +1,10 @@
use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use tracing::info;

use crate::utils::gguf_metadata::ContentMetadata;

use super::Content;

struct PropsGGUFTemplate {
chat_template: Option<String>,
}
@@ -23,10 +24,12 @@ impl TryFrom<ContentMetadata<'_>> for PropsGGUFTemplate {
}

// Get chat template from GGUF metadata if it exists
pub fn get_gguf_chat_template(content: &Content) -> Result<Option<String>> {
pub fn get_gguf_chat_template<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Result<Option<String>> {
let metadata = ContentMetadata {
path_prefix: "tokenizer",
metadata: &content.metadata,
metadata: content.get_metadata(),
};
let props = PropsGGUFTemplate::try_from(metadata)?;
if let Some(ref chat_template) = props.chat_template {
172 changes: 172 additions & 0 deletions mistralrs-core/src/gguf/content.rs
@@ -0,0 +1,172 @@
use std::{collections::HashMap, fs};

use anyhow::Context;
use candle_core::{
quantized::{
gguf_file::{self, Value},
QTensor,
},
Device, Result,
};
use indexmap::IndexMap;
use tracing::info;

use crate::DEBUG;

use super::GGUFArchitecture;

fn parse_gguf_value(value: &Value) -> String {
match value {
Value::Array(vs) => vs
.iter()
.map(parse_gguf_value)
.collect::<Vec<String>>()
.join(", "),
Value::Bool(b) => b.to_string(),
Value::F32(x) => x.to_string(),
Value::F64(x) => x.to_string(),
Value::I8(x) => x.to_string(),
Value::I16(x) => x.to_string(),
Value::I32(x) => x.to_string(),
Value::I64(x) => x.to_string(),
Value::String(x) => x.to_string(),
Value::U8(x) => x.to_string(),
Value::U16(x) => x.to_string(),
Value::U32(x) => x.to_string(),
Value::U64(x) => x.to_string(),
}
}

// Internal invariant: contents and readers must be paired.
/// This abstracts the files of a GGUF model, allowing a model split across multiple files to be loaded as one.
pub struct Content<'a, R: std::io::Seek + std::io::Read> {
contents: Vec<gguf_file::Content>,
readers: &'a mut [&'a mut R],
arch: GGUFArchitecture,
all_metadata: HashMap<String, Value>,
}

impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> {
/// Create a `Content` from a set of file readers.
pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result<Self> {
let mut contents = Vec::new();
let n_readers = readers.len();
for reader in readers.iter_mut() {
contents.push(gguf_file::Content::read(reader)?);
}
let n_splits = contents
.iter()
.filter_map(|ct| {
ct.metadata
.get("split.count")
.map(|val| val.to_u64().unwrap())
})
.collect::<Vec<_>>();
if n_splits.len() > 1 {
candle_core::bail!("Multiple contents have multiple `split.count` fields");
}
#[allow(clippy::cast_possible_truncation)]
if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
candle_core::bail!("Number of readers does not match the number of splits.");
} else if n_splits.len() == 1 {
info!("Model n splits: {}", n_splits[0]);
}

let mut arch = None;
for ct in &contents {
if !ct.metadata.contains_key("general.architecture") {
continue;
}

arch = Some(
ct.metadata["general.architecture"]
.to_string()
.context("Model metadata should have declared an architecture")
.and_then(GGUFArchitecture::from_value)
.unwrap(),
);
}
let arch = arch.expect("GGUF files must specify `general.architecture`");

let mut all_metadata = HashMap::new();
for content in &contents {
all_metadata.extend(content.metadata.clone())
}

Ok(Self {
contents,
readers,
arch,
all_metadata,
})
}

pub fn arch(&self) -> GGUFArchitecture {
self.arch
}

/// Retrieve a tensor, searching through each content.
pub fn tensor(&mut self, name: &str, device: &Device) -> Result<QTensor> {
for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) {
if let Some(tensor_info) = ct.tensor_infos.get(name) {
return tensor_info.read(reader, ct.tensor_data_offset, device);
}
}
candle_core::bail!("Cannot find tensor info for {name}")
}

/// Print metadata for these contents.
/// This will also log each tensor's name, shape and dtype to `mistralrs_gguf_tensors.txt` if DEBUG is enabled.
pub fn print_metadata(&self) -> anyhow::Result<()> {
// Collect metadata keys from every content (plus tensor info when DEBUG is enabled)
let mut keys = Vec::new();
let mut metadatas = Vec::new();
let mut tensors = Vec::new();
for ct in &self.contents {
keys.extend(ct.metadata.keys());
metadatas.push(&ct.metadata);

if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
for (name, info) in &ct.tensor_infos {
tensors.push(format!(
"name = `{name}`, shape = {:?}, dtype = {:?}",
info.shape.clone(),
info.ggml_dtype
));
}
}
}

info!("Model config:");
keys.sort();
let mut output_keys = IndexMap::new();
for name in keys {
if !name.contains("tokenizer") {
for metadata in &metadatas {
if let Some(val) = metadata.get(name) {
output_keys.insert(name, parse_gguf_value(val));
}
}
}
}
for (name, val) in output_keys {
println!("{name}: {val}")
}

if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
fs::write(
"mistralrs_gguf_tensors.txt",
serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
)?;

info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
}

anyhow::Ok(())
}

/// Get the merged metadata across all contents
pub fn get_metadata(&self) -> &HashMap<String, Value> {
&self.all_metadata
}
}
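
For orientation, here is a minimal sketch of how a pipeline might drive the new multi-file API end to end. The shard filenames and the `token_embd.weight` tensor name are illustrative assumptions, not part of this diff; any reader that is `Read + Seek` (such as `std::fs::File`) works.

use std::fs::File;

use candle_core::{Device, Result};

use crate::gguf::Content;

fn load_sharded_model() -> Result<()> {
    // Open each shard; `from_readers` validates `split.count` inside
    // the files against the number of readers supplied.
    let mut shard0 = File::open("model-00001-of-00002.gguf")?;
    let mut shard1 = File::open("model-00002-of-00002.gguf")?;
    let mut readers = vec![&mut shard0, &mut shard1];
    let mut content = Content::from_readers(&mut readers)?;

    // Metadata from every shard has been merged into one map.
    println!("arch: {:?}", content.arch());

    // Tensor lookup transparently searches each shard.
    let device = Device::Cpu;
    let embeddings = content.tensor("token_embd.weight", &device)?;
    println!("shape: {:?}", embeddings.shape());
    Ok(())
}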
9 changes: 6 additions & 3 deletions mistralrs-core/src/gguf/gguf_tokenizer.rs
@@ -3,7 +3,6 @@
use std::{collections::HashMap, sync::atomic::Ordering};

use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use itertools::Itertools;
use tokenizers::{
decoders::{
@@ -23,6 +22,8 @@ use tracing::info;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;

use super::Content;

pub(crate) struct GgufTokenizerConversion {
pub tokenizer: Tokenizer,
pub bos: Option<String>,
@@ -71,10 +72,12 @@ struct AddedTokensCollection {
unk: Option<String>,
}

pub fn convert_gguf_to_hf_tokenizer(content: &Content) -> Result<GgufTokenizerConversion> {
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
let metadata = ContentMetadata {
path_prefix: "tokenizer.ggml",
metadata: &content.metadata,
metadata: content.get_metadata(),
};
let props = PropsGGUF::try_from(metadata)?;

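
Because both helpers now read through `Content::get_metadata()`, they see the union of every shard's metadata. A crate-internal sketch of the call shape (the wrapper function itself is illustrative):

use crate::gguf::{convert_gguf_to_hf_tokenizer, get_gguf_chat_template, Content};

fn tokenizer_and_template<R: std::io::Seek + std::io::Read>(
    content: &Content<'_, R>,
) -> anyhow::Result<()> {
    // A tokenizer stored in one shard and a chat template stored in
    // another are both found in the merged metadata map.
    let conversion = convert_gguf_to_hf_tokenizer(content)?;
    let template = get_gguf_chat_template(content)?;
    println!("bos: {:?}, template present: {}", conversion.bos, template.is_some());
    Ok(())
}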
37 changes: 36 additions & 1 deletion mistralrs-core/src/gguf/mod.rs
@@ -1,5 +1,40 @@
mod chat_template;
mod content;
mod gguf_tokenizer;
use strum::EnumString;

pub use chat_template::get_gguf_chat_template;
use anyhow::{Context, Result};
pub(crate) use chat_template::get_gguf_chat_template;
pub(crate) use content::Content;
pub(crate) use gguf_tokenizer::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion};
use std::str::FromStr;

pub const GGUF_MULTI_FILE_DELIMITER: &str = " ";

#[derive(Debug, EnumString, Clone, Copy)]
#[strum(serialize_all = "kebab-case")]
pub enum GGUFArchitecture {
Llama,
Mpt,
Gptneox,
Gptj,
Gpt2,
Bloom,
Falcon,
Mamba,
Rwkv,
Phi2,
Phi3,
Starcoder2,
}

// Wraps from_str() for some convenience:
// - Case-insensitive variant matching (TODO: is this desirable?)
// - Customized error until potential upstream support: https://github.com/Peternator7/strum/issues/332
impl GGUFArchitecture {
pub fn from_value<T: AsRef<str> + std::fmt::Display>(value: T) -> Result<Self> {
Self::from_str(&value.as_ref().to_ascii_lowercase())
.with_context(|| format!("Unknown GGUF architecture `{value}`"))
.map_err(anyhow::Error::msg)
}
}
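
Two small pieces of glue above deserve a concrete example: splitting a single CLI-style argument on `GGUF_MULTI_FILE_DELIMITER`, and the case-insensitive architecture lookup. A sketch, with illustrative filenames:

use anyhow::Result;

use crate::gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};

fn demo() -> Result<()> {
    // One argument string carrying two shard names, split on the delimiter.
    let flag = "model-00001-of-00002.gguf model-00002-of-00002.gguf";
    let filenames: Vec<String> = flag
        .split(GGUF_MULTI_FILE_DELIMITER)
        .map(ToOwned::to_owned)
        .collect();
    assert_eq!(filenames.len(), 2);

    // `from_value` lowercases before matching, so "PHI3" and "phi3" both parse.
    let arch = GGUFArchitecture::from_value("PHI3")?;
    println!("{arch:?}"); // Phi3
    Ok(())
}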
12 changes: 6 additions & 6 deletions mistralrs-core/src/layers.rs
@@ -12,7 +12,7 @@ use std::{
};

use candle_core::{
quantized::{gguf_file, QMatMul, QTensor},
quantized::{QMatMul, QTensor},
DType, Device, IndexOp, Result, Shape, Tensor, D,
};
use candle_nn::{Linear, Module, VarBuilder};
@@ -22,7 +22,8 @@ use serde::Deserialize;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv};
use crate::{
cublaslt::CUBLASLT_HANDLE, models::llama, pipeline::Phi3RopeScaling, INHIBIT_GEMM_F16,
cublaslt::CUBLASLT_HANDLE, gguf::Content, models::llama, pipeline::Phi3RopeScaling,
INHIBIT_GEMM_F16,
};

#[derive(Debug, Clone)]
@@ -612,13 +613,12 @@ pub struct QLinear {

impl QLinear {
pub fn new<R: std::io::Read + std::io::Seek>(
ct: &gguf_file::Content,
r: &mut R,
ct: &mut Content<'_, R>,
name: &str,
device: &Device,
) -> Result<Self> {
let w = ct.tensor(r, &format!("{name}.weight"), device)?;
let b = ct.tensor(r, &format!("{name}.bias"), device)?;
let w = ct.tensor(&format!("{name}.weight"), device)?;
let b = ct.tensor(&format!("{name}.bias"), device)?;
let inner = QMatMul::from_qtensor(w)?;
let bias = b.dequantize(device)?;
Ok(Self {
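
The `QLinear` constructor no longer takes a separate reader: `Content` owns its readers and resolves each tensor from whichever shard holds it. A minimal sketch of the new call shape (the `output` prefix is an illustrative tensor name):

use candle_core::{Device, Result};

use crate::{gguf::Content, layers::QLinear};

fn build_output_head<R: std::io::Seek + std::io::Read>(
    content: &mut Content<'_, R>,
    device: &Device,
) -> Result<QLinear> {
    // Loads `output.weight` and `output.bias`, wherever they live.
    QLinear::new(content, "output", device)
}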
14 changes: 7 additions & 7 deletions mistralrs-core/src/lib.rs
@@ -60,17 +60,17 @@ mod xlora_models;

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{DeviceLayerMapMetadata, DeviceMapMetadata, LayerDeviceMapper};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_quant::IsqType;
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig};
pub use pipeline::{
chat_template::ChatTemplate, parse_isq_value, AnyMoeLoader, AnyMoePipeline, GGMLLoader,
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder,
GemmaLoader, Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
LocalModelPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader,
NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
VisionModelLoader, VisionSpecificConfig,
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GemmaLoader,
Idefics2Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, Starcoder2Loader, TokenSource,
VisionLoader, VisionLoaderBuilder, VisionLoaderType, VisionModelLoader, VisionSpecificConfig,
};
pub use request::{Constraint, MessageContent, NormalRequest, Request, RequestMessage};
pub use response::Response;
17 changes: 13 additions & 4 deletions mistralrs-core/src/model_loader.rs
@@ -7,7 +7,7 @@ use crate::{
get_toml_selected_model_dtype,
pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig},
Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector,
VisionLoaderBuilder, VisionSpecificConfig,
VisionLoaderBuilder, VisionSpecificConfig, GGUF_MULTI_FILE_DELIMITER,
};

/// A builder for a loader using the selected model.
@@ -189,7 +189,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
args.chat_template,
tok_model_id,
quantized_model_id,
quantized_filename,
quantized_filename
.split(GGUF_MULTI_FILE_DELIMITER)
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
args.prompt_batchsize,
)
.build(),
@@ -204,7 +207,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
args.chat_template,
tok_model_id,
quantized_model_id,
quantized_filename,
quantized_filename
.split(GGUF_MULTI_FILE_DELIMITER)
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
args.prompt_batchsize,
)
.with_xlora(
@@ -227,7 +233,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
args.chat_template,
tok_model_id,
quantized_model_id,
quantized_filename,
quantized_filename
.split(GGUF_MULTI_FILE_DELIMITER)
.map(ToOwned::to_owned)
.collect::<Vec<_>>(),
args.prompt_batchsize,
)
.with_lora(