Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for mistral nemo #595

Merged
merged 2 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions mistralrs-core/src/models/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ pub struct Config {
pub(crate) rope_theta: f64,
pub(crate) sliding_window: Option<usize>,
pub(crate) use_flash_attn: bool,
pub(crate) head_dim: Option<usize>,
}

impl Config {
pub(crate) fn head_dim(&self) -> usize {
self.head_dim
.unwrap_or(self.hidden_size / self.num_attention_heads)
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -141,7 +149,6 @@ struct Attention {
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
sliding_window: Option<usize>,
Expand All @@ -159,7 +166,7 @@ impl Attention {
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let head_dim = cfg.head_dim();
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
Expand All @@ -173,7 +180,6 @@ impl Attention {
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
use_flash_attn: cfg.use_flash_attn,
sliding_window: cfg.sliding_window,
Expand Down Expand Up @@ -282,11 +288,9 @@ impl Attention {
attn_output = attn_output.to_dtype(DType::F32)?;
}
attn_output = if attention_mask.is_some() {
attn_output
.transpose(1, 2)?
.reshape(&[b_sz, q_len, self.hidden_size])?
attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
} else {
attn_output.reshape(&[b_sz, q_len, self.hidden_size])?
attn_output.reshape((b_sz, q_len, ()))?
};

let mut res = MatMul.qmatmul(&attn_output, &self.o_proj)?;
Expand Down Expand Up @@ -419,7 +423,7 @@ impl Model {
cfg.hidden_size,
mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
)?;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let head_dim = cfg.head_dim();
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/pipeline/normal_loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ struct MistralBasicConfig {
rms_norm_eps: f64,
rope_theta: f64,
sliding_window: Option<usize>,
head_dim: Option<usize>,
}

impl MistralBasicConfig {
Expand All @@ -132,6 +133,7 @@ impl MistralBasicConfig {
rope_theta: basic_config.rope_theta,
sliding_window: basic_config.sliding_window,
use_flash_attn,
head_dim: basic_config.head_dim,
})
}
}
Expand Down
1 change: 1 addition & 0 deletions mistralrs-core/src/vision_models/idefics2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ impl From<TextConfig> for mistral::Config {
rope_theta: val.rope_theta,
sliding_window: val.sliding_window,
use_flash_attn: val.use_flash_attn,
head_dim: None,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions mistralrs-core/src/vision_models/llava/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ impl Config {
rope_theta: self.text_config.rope_theta as f64,
sliding_window: self.text_config.sliding_window,
use_flash_attn: self.use_flash_attn,
head_dim: None,
}
}

Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/xlora_models/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl Attention {
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let head_dim = cfg.head_dim();
let q_proj = linear_no_bias(
hidden_sz,
num_heads * head_dim,
Expand Down Expand Up @@ -469,7 +469,7 @@ impl XLoraModel {
cfg.hidden_size,
mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
)?;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let head_dim = cfg.head_dim();
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
let mut count = 0;
Expand Down
Loading