diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index ade76a901..35654a99d 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -27,6 +27,15 @@ impl RmsNorm {
             weight: w,
         })
     }
+
+    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
+        let inner = candle_nn::RmsNorm::new(w.clone(), eps);
+        Ok(Self {
+            inner,
+            eps,
+            weight: w,
+        })
+    }
 }
 
 impl Module for RmsNorm {
diff --git a/mistralrs-core/src/models/mod.rs b/mistralrs-core/src/models/mod.rs
index d0a47bcf3..223a6ae4f 100644
--- a/mistralrs-core/src/models/mod.rs
+++ b/mistralrs-core/src/models/mod.rs
@@ -12,6 +12,7 @@ pub(crate) mod phi2;
 pub(crate) mod phi3;
 pub(crate) mod quantized_llama;
 pub(crate) mod quantized_phi2;
+pub(crate) mod quantized_phi3;
 pub(crate) mod qwen2;
 
 pub type LayerCaches = Vec<Option<(Tensor, Tensor)>>;
diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs
new file mode 100644
index 000000000..f1d3117dd
--- /dev/null
+++ b/mistralrs-core/src/models/quantized_phi3.rs
@@ -0,0 +1,331 @@
+#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
+
+use std::collections::HashMap;
+
+use crate::device_map::DeviceMapper;
+use crate::layers::RmsNorm;
+use crate::DeviceMapMetadata;
+use candle_core::quantized::gguf_file;
+use candle_core::quantized::QMatMul;
+use candle_core::quantized::QTensor;
+use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::Embedding;
+
+use super::repeat_kv;
+use super::verify_sanity_gguf;
+use super::Cache;
+
+#[derive(Debug, Clone)]
+struct Mlp {
+    ffn_up: QMatMul,
+    ffn_down: QMatMul,
+    i_size: usize,
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let up_states = xs.apply(&self.ffn_up)?;
+        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
+        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
+        let up_states = (up_states * gate.silu()?)?;
+        up_states.apply(&self.ffn_down)
+    }
+}
+
+fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
+    let w = w.dequantize(&w.device())?;
+    let rms = RmsNorm::from_w(w, eps)?;
+    Ok(rms)
+}
+
+#[derive(Debug, Clone)]
+struct LayerWeights {
+    attn_qkv: QMatMul,
+    attn_output: QMatMul,
+    attn_norm: RmsNorm,
+    ffn_norm: RmsNorm,
+    mlp: Mlp,
+    n_head: usize,
+    n_kv_head: usize,
+    head_dim: usize,
+    cos: Tensor,
+    sin: Tensor,
+    neg_inf: Tensor,
+    sliding_window: usize,
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
+    let shape = mask.shape();
+    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
+    Ok(m)
+}
+
+impl LayerWeights {
+    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
+        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
+        let mut outputs = Vec::new();
+        for (i, offset) in seqlen_offsets.iter().enumerate() {
+            let cos = self.cos.narrow(0, *offset, seq_len)?;
+            let sin = self.sin.narrow(0, *offset, seq_len)?;
+            outputs.push(candle_nn::rotary_emb::rope(
+                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
+                &cos,
+                &sin,
+            )?);
+        }
+        Tensor::cat(&outputs, 0)
+    }
+
+    fn forward_attn(
+        &mut self,
+        x: &Tensor,
+        mask: Option<&Tensor>,
+        seqlen_offsets: &[usize],
+        kv_cache: &mut Option<(Tensor, Tensor)>,
+    ) -> Result<Tensor> {
+        let (b_sz, seq_len, n_embd) = x.dims3()?;
+        let qkv = self.attn_qkv.forward(x)?;
+
+        let query_pos = self.n_head * self.head_dim;
+        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
+        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
+        let v = qkv.narrow(
+            D::Minus1,
+            query_pos + self.n_kv_head * self.head_dim,
+            self.n_kv_head * self.head_dim,
+        )?;
+
+        let q = q
+            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
+            .transpose(1, 2)?;
+
+        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
+        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
+
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = mask.cloned();
+                let kv_seq_len = prev_k.dim(2)?;
+                let sliding_window = self.sliding_window;
+                if kv_seq_len > sliding_window {
+                    prev_k =
+                        prev_k.narrow(2, kv_seq_len - (sliding_window - 1), sliding_window - 1)?;
+                    prev_v =
+                        prev_v.narrow(2, kv_seq_len - (sliding_window - 1), sliding_window - 1)?;
+                    if let Some(ref mut mask) = mask {
+                        let mask_len = mask.dim(1)?;
+                        *mask =
+                            mask.narrow(1, mask_len - (sliding_window - 1), sliding_window - 1)?;
+                        *mask = Tensor::cat(
+                            &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                            D::Minus1,
+                        )?;
+                    }
+                }
+                let k = Tensor::cat(&[prev_k, k], 2)?;
+                let v = Tensor::cat(&[prev_v, v], 2)?;
+                (k, v, mask)
+            }
+        };
+        *kv_cache = Some((k.clone(), v.clone()));
+
+        let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
+        let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
+
+        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+        let att = match attn_mask {
+            None => att,
+            Some(mask) => {
+                let mask = mask.broadcast_as(att.shape())?;
+                masked_fill(&att, &mask, &self.neg_inf)?
+            }
+        };
+        let att = candle_nn::ops::softmax_last_dim(&att)?;
+        // Convert to contiguous as matmul doesn't support strided vs for now.
+        let y = att.matmul(&v.contiguous()?)?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+        let y = self.attn_output.forward(&y)?;
+        Ok(y)
+    }
+}
+
+#[derive(Debug)]
+pub struct ModelWeights {
+    tok_embeddings: Embedding,
+    layers: Vec<LayerWeights>,
+    output_norm: RmsNorm,
+    output: QMatMul,
+    masks: HashMap<usize, Tensor>,
+    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
+    pub device: Device,
+    pub cache: Cache,
+    pub max_seq_len: usize,
+}
+
+fn precomput_freqs_cis(
+    head_dim: usize,
+    freq_base: f32,
+    device: &Device,
+    context_window: usize,
+) -> Result<(Tensor, Tensor)> {
+    let theta: Vec<_> = (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
+        .collect();
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let idx_theta = Tensor::arange(0, context_window as u32, device)?
+        .to_dtype(DType::F32)?
+        .reshape((context_window, 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    let cos = idx_theta.cos()?;
+    let sin = idx_theta.sin()?;
+    Ok((cos, sin))
+}
+
+impl ModelWeights {
+    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+        ct: gguf_file::Content,
+        reader: &mut R,
+        device: &Device,
+        mapper: DeviceMapMetadata,
+    ) -> Result<Self> {
+        let md_get = |s: &str| match ct.metadata.get(s) {
+            None => candle_core::bail!("cannot find {s} in metadata"),
+            Some(v) => Ok(v),
+        };
+        verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?;
+
+        // Parameter extraction from metadata.
+        let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
+        let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
+        let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
+        let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
+        let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
+        let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
+        let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+        let context_window = md_get("phi3.context_length")?.to_u32()? as usize;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?;
+        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
+
+        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = tok_embeddings.dequantize(device)?;
+        let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
+        let output = QMatMul::from_qtensor(ct.tensor(reader, "output.weight", device)?)?;
+        let mut layers = Vec::with_capacity(block_count);
+        let mapper = mapper.into_mapper(block_count, device)?;
+        for layer_idx in 0..block_count {
+            let prefix = format!("blk.{layer_idx}");
+            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
+            let ffn_up = QMatMul::from_qtensor(ct.tensor(
+                reader,
+                &format!("{prefix}.ffn_up.weight"),
+                device,
+            )?)?;
+            let ffn_down = QMatMul::from_qtensor(ct.tensor(
+                reader,
+                &format!("{prefix}.ffn_down.weight"),
+                device,
+            )?)?;
+            let mlp = Mlp {
+                ffn_up,
+                ffn_down,
+                i_size,
+            };
+            let attn_norm = rms_norm(
+                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
+                rms_eps,
+            )?;
+            let ffn_norm = rms_norm(
+                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
+                rms_eps,
+            )?;
+            layers.push(LayerWeights {
+                attn_qkv: QMatMul::from_qtensor(ct.tensor(
+                    reader,
+                    &format!("{prefix}.attn_qkv.weight"),
+                    device,
+                )?)?,
+                attn_output: QMatMul::from_qtensor(ct.tensor(
+                    reader,
+                    &format!("{prefix}.attn_output.weight"),
+                    device,
+                )?)?,
+                attn_norm,
+                ffn_norm,
+                mlp,
+                n_head: head_count,
+                n_kv_head: head_count_kv,
+                head_dim: embedding_length / head_count,
+                cos: cos.clone(),
+                sin: sin.clone(),
+                neg_inf: neg_inf.clone(),
+                sliding_window: context_window,
+            })
+        }
+        Ok(Self {
+            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+            layers,
+            output_norm,
+            output,
+            masks: HashMap::new(),
+            mapper: Some(mapper),
+            device: device.clone(),
+            cache: Cache::new(block_count, false),
+            max_seq_len: context_window,
+        })
+    }
+
+    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (t, t), device)?;
+            self.masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+
+    pub fn forward(&mut self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
+        let (_b_sz, seq_len) = xs.dims2()?;
+        let mask = if seq_len == 1 {
+            None
+        } else {
+            Some(self.mask(seq_len, xs.device())?)
+        };
+        let mut xs = self.tok_embeddings.forward(xs)?;
+        let mut cache = self.cache.lock();
+        for (i, layer) in self.layers.iter_mut().enumerate() {
+            if let Some(ref mapper) = self.mapper {
+                xs = mapper.map(xs, i)?;
+            }
+            let residual = &xs;
+            let ys = xs.apply(&layer.attn_norm)?;
+            let ys = layer.forward_attn(
+                &ys,
+                mask.as_ref()
+                    .map(|m| m.to_device(xs.device()).unwrap())
+                    .as_ref(),
+                seqlen_offsets,
+                &mut cache[i],
+            )?;
+            let ys = (ys + residual)?;
+            let residual = &ys;
+            let ys = ys.apply(&layer.ffn_norm)?;
+            let ys = layer.mlp.forward(&ys)?;
+            xs = (ys + residual)?
+        }
+        let xs = xs.to_device(&self.device)?;
+        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
+        self.output.forward(&xs)
+    }
+}
diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs
index 8a567c6e4..fa586e137 100644
--- a/mistralrs-core/src/pipeline/gguf.rs
+++ b/mistralrs-core/src/pipeline/gguf.rs
@@ -15,7 +15,8 @@ use crate::xlora_models::NonGranularState;
 use crate::{deserialize_chat_template, do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata};
 use crate::{
     models::quantized_llama::ModelWeights as QLlama, models::quantized_phi2::ModelWeights as QPhi,
-    utils::tokens::get_token, xlora_models::XLoraModelWeights as XLoraQLlama,
+    models::quantized_phi3::ModelWeights as QPhi3, utils::tokens::get_token,
+    xlora_models::XLoraModelWeights as XLoraQLlama,
 };
 use anyhow::{bail, Result};
 use candle_core::quantized::{
@@ -40,6 +41,7 @@ enum Model {
     Llama(QLlama),
     Phi2(QPhi),
     XLoraLlama(XLoraQLlama),
+    Phi3(QPhi3),
 }
 
 pub struct GGUFPipeline {
@@ -79,6 +81,7 @@ enum GGUFArchitecture {
     Mamba,
     Rwkv,
     Phi2,
+    Phi3,
 }
 
 impl FromStr for GGUFArchitecture {
@@ -96,6 +99,7 @@ impl FromStr for GGUFArchitecture {
             "mamba" => Ok(GGUFArchitecture::Mamba),
             "rwkv" => Ok(GGUFArchitecture::Rwkv),
             "phi2" => Ok(GGUFArchitecture::Phi2),
+            "phi3" => Ok(GGUFArchitecture::Phi3),
             a => Err(format!("Unknown GGUF architecture `{a}`")),
         }
     }
@@ -332,6 +336,9 @@ impl Loader for GGUFLoader {
                 GGUFArchitecture::Phi2 => {
                     Model::Phi2(QPhi::from_gguf(model, &mut file, device, mapper)?)
                 }
+                GGUFArchitecture::Phi3 => {
+                    Model::Phi3(QPhi3::from_gguf(model, &mut file, device, mapper)?)
+                }
                 a => bail!("Unsupported architecture `{a:?}`"),
             },
             ModelKind::XLoraGGUF => {
@@ -405,16 +412,18 @@ impl Loader for GGUFLoader {
             Model::Llama(ref l) => l.max_seq_len,
             Model::Phi2(ref p) => p.max_seq_len,
             Model::XLoraLlama(ref xl) => xl.max_seq_len,
+            Model::Phi3(ref p) => p.max_seq_len,
         };
         let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
         let is_xlora = match &model {
-            Model::Llama(_) | Model::Phi2(_) => false,
+            Model::Llama(_) | Model::Phi2(_) | Model::Phi3(_) => false,
             Model::XLoraLlama(_) => !is_lora,
         };
         let num_hidden_layers = match model {
             Model::Llama(ref model) => model.cache.lock().len(),
             Model::Phi2(ref model) => model.cache.lock().len(),
             Model::XLoraLlama(ref model) => model.cache.lock().len(),
+            Model::Phi3(ref model) => model.cache.lock().len(),
         };
         let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
         Ok(Arc::new(Mutex::new(GGUFPipeline {
@@ -488,6 +497,7 @@ impl Pipeline for GGUFPipeline {
                 &self.non_granular_state,
                 context_lens,
             ),
+            Model::Phi3(ref mut model) => model.forward(&input_ids, &seqlen_offsets),
         }
     }
     async fn sample(
@@ -505,6 +515,7 @@ impl Pipeline for GGUFPipeline {
             Model::Llama(ref model) => model.device.clone(),
             Model::Phi2(ref model) => model.device.clone(),
             Model::XLoraLlama(ref model) => model.device.clone(),
+            Model::Phi3(ref model) => model.device.clone(),
         }
     }
     fn tokenizer(&self) -> Arc<Tokenizer> {
@@ -547,6 +558,7 @@ impl Pipeline for GGUFPipeline {
             Model::Llama(ref model) => &model.cache,
             Model::Phi2(ref model) => &model.cache,
             Model::XLoraLlama(ref model) => &model.cache,
+            Model::Phi3(ref model) => &model.cache,
         }
     }
 }
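
Note on the fused projection used in `forward_attn` above: the new model assumes the GGUF `attn_qkv` weight packs the query, key, and value projections along the last output dimension, in that order, so the result is split with three `narrow` calls before being reshaped into heads. The sketch below replays just that split on a dummy tensor; the dimensions and the `main` wrapper are illustrative only and not taken from the PR.

```rust
use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    // Illustrative sizes only: 2 query heads, 1 KV head, head_dim of 4.
    let (n_head, n_kv_head, head_dim) = (2usize, 1usize, 4usize);
    let fused = (n_head + 2 * n_kv_head) * head_dim;
    // Stand-in for the output of the fused attn_qkv projection: (batch, seq, fused).
    let qkv = Tensor::zeros((1, 3, fused), DType::F32, &Device::Cpu)?;

    // Same layout assumption as the patch: queries first, then keys, then values.
    let query_pos = n_head * head_dim;
    let q = qkv.narrow(D::Minus1, 0, query_pos)?;
    let k = qkv.narrow(D::Minus1, query_pos, n_kv_head * head_dim)?;
    let v = qkv.narrow(D::Minus1, query_pos + n_kv_head * head_dim, n_kv_head * head_dim)?;

    assert_eq!(q.dims(), &[1, 3, n_head * head_dim]);
    assert_eq!(k.dims(), &[1, 3, n_kv_head * head_dim]);
    assert_eq!(v.dims(), &[1, 3, n_kv_head * head_dim]);
    Ok(())
}
```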