diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs
index 2ca529117..deb9c04b4 100644
--- a/mistralrs-core/src/models/mistral.rs
+++ b/mistralrs-core/src/models/mistral.rs
@@ -39,6 +39,14 @@ pub struct Config {
     pub(crate) rope_theta: f64,
     pub(crate) sliding_window: Option<usize>,
     pub(crate) use_flash_attn: bool,
+    pub(crate) head_dim: Option<usize>,
+}
+
+impl Config {
+    pub(crate) fn head_dim(&self) -> usize {
+        self.head_dim
+            .unwrap_or(self.hidden_size / self.num_attention_heads)
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -141,7 +149,6 @@ struct Attention {
     num_kv_heads: usize,
     num_kv_groups: usize,
     head_dim: usize,
-    hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     use_flash_attn: bool,
     sliding_window: Option<usize>,
@@ -159,7 +166,7 @@ impl Attention {
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
         let num_kv_groups = num_heads / num_kv_heads;
-        let head_dim = hidden_sz / num_heads;
+        let head_dim = cfg.head_dim();
         let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
         let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
         let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
@@ -173,7 +180,6 @@ impl Attention {
             num_kv_heads,
             num_kv_groups,
             head_dim,
-            hidden_size: hidden_sz,
             rotary_emb,
             use_flash_attn: cfg.use_flash_attn,
             sliding_window: cfg.sliding_window,
@@ -282,11 +288,9 @@ impl Attention {
             attn_output = attn_output.to_dtype(DType::F32)?;
         }
         attn_output = if attention_mask.is_some() {
-            attn_output
-                .transpose(1, 2)?
-                .reshape(&[b_sz, q_len, self.hidden_size])?
+            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
         } else {
-            attn_output.reshape(&[b_sz, q_len, self.hidden_size])?
+            attn_output.reshape((b_sz, q_len, ()))?
         };
 
         let mut res = MatMul.qmatmul(&attn_output, &self.o_proj)?;
@@ -419,7 +423,7 @@ impl Model {
             cfg.hidden_size,
             mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
         )?;
-        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+        let head_dim = cfg.head_dim();
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in
diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs
index 0a70972a8..7f7a00afa 100644
--- a/mistralrs-core/src/pipeline/normal_loaders.rs
+++ b/mistralrs-core/src/pipeline/normal_loaders.rs
@@ -114,6 +114,7 @@ struct MistralBasicConfig {
     rms_norm_eps: f64,
     rope_theta: f64,
     sliding_window: Option<usize>,
+    head_dim: Option<usize>,
 }
 
 impl MistralBasicConfig {
@@ -132,6 +133,7 @@ impl MistralBasicConfig {
             rope_theta: basic_config.rope_theta,
             sliding_window: basic_config.sliding_window,
             use_flash_attn,
+            head_dim: basic_config.head_dim,
         })
     }
 }
diff --git a/mistralrs-core/src/vision_models/idefics2.rs b/mistralrs-core/src/vision_models/idefics2.rs
index c88e40b78..ff47c5b40 100644
--- a/mistralrs-core/src/vision_models/idefics2.rs
+++ b/mistralrs-core/src/vision_models/idefics2.rs
@@ -187,6 +187,7 @@ impl From for mistral::Config {
             rope_theta: val.rope_theta,
             sliding_window: val.sliding_window,
             use_flash_attn: val.use_flash_attn,
+            head_dim: None,
         }
     }
 }
diff --git a/mistralrs-core/src/vision_models/llava/config.rs b/mistralrs-core/src/vision_models/llava/config.rs
index 87deb1667..87196ea57 100644
--- a/mistralrs-core/src/vision_models/llava/config.rs
+++ b/mistralrs-core/src/vision_models/llava/config.rs
@@ -94,6 +94,7 @@ impl Config {
             rope_theta: self.text_config.rope_theta as f64,
             sliding_window: self.text_config.sliding_window,
             use_flash_attn: self.use_flash_attn,
+            head_dim: None,
         }
     }
 
diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs
index 22470b9fa..ff7f8b9c7 100644
--- a/mistralrs-core/src/xlora_models/mistral.rs
+++ b/mistralrs-core/src/xlora_models/mistral.rs
@@ -161,7 +161,7 @@ impl Attention {
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
         let num_kv_groups = num_heads / num_kv_heads;
-        let head_dim = hidden_sz / num_heads;
+        let head_dim = cfg.head_dim();
         let q_proj = linear_no_bias(
             hidden_sz,
             num_heads * head_dim,
@@ -469,7 +469,7 @@ impl XLoraModel {
            cfg.hidden_size,
             mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
         )?;
-        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+        let head_dim = cfg.head_dim();
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         let mut count = 0;
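For reference, a minimal standalone sketch of the fallback semantics this patch introduces: an explicit head_dim from the model config wins, otherwise head_dim is derived as hidden_size / num_attention_heads. The Config struct below is a stripped-down stand-in for illustration (not the crate's actual type), and the numbers are Mistral-7B-style values.

// Illustrative stand-in for the patched Config; only the fields used by
// head_dim() are included here.
struct Config {
    hidden_size: usize,
    num_attention_heads: usize,
    head_dim: Option<usize>,
}

impl Config {
    fn head_dim(&self) -> usize {
        // Prefer an explicitly configured head_dim; otherwise fall back to
        // the classic hidden_size / num_attention_heads split.
        self.head_dim
            .unwrap_or(self.hidden_size / self.num_attention_heads)
    }
}

fn main() {
    // No explicit head_dim: derived as 4096 / 32 = 128 (Mistral-7B-style).
    let implicit = Config {
        hidden_size: 4096,
        num_attention_heads: 32,
        head_dim: None,
    };
    assert_eq!(implicit.head_dim(), 128);

    // An explicit head_dim overrides the derived value.
    let explicit = Config {
        hidden_size: 4096,
        num_attention_heads: 32,
        head_dim: Some(256),
    };
    assert_eq!(explicit.head_dim(), 256);
}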