diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs
index ef62a5bbb..1c0a1551e 100644
--- a/mistralrs-core/src/models/gemma.rs
+++ b/mistralrs-core/src/models/gemma.rs
@@ -208,7 +208,7 @@ impl Attention {
                 .contiguous()?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
         let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs
index 4c1ebf22e..aee6b741c 100644
--- a/mistralrs-core/src/models/llama.rs
+++ b/mistralrs-core/src/models/llama.rs
@@ -89,7 +89,8 @@ impl CausalSelfAttention {
                 .contiguous()?;
         }
 
-        let (k, v) = crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v)?;
+        let (k, v) =
+            crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;
 
         let k = repeat_kv(k, self.num_attention_heads / self.num_key_value_heads)?.contiguous()?;
         let v = repeat_kv(v, self.num_attention_heads / self.num_key_value_heads)?.contiguous()?;
diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs
index 5d75aefc2..49d402f53 100644
--- a/mistralrs-core/src/models/mistral.rs
+++ b/mistralrs-core/src/models/mistral.rs
@@ -165,6 +165,7 @@ impl Attention {
             v,
             attention_mask,
             self.sliding_window,
+            false,
         )?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs
index b1ba38317..6dbc78826 100644
--- a/mistralrs-core/src/models/mixtral.rs
+++ b/mistralrs-core/src/models/mixtral.rs
@@ -129,6 +129,7 @@ impl Attention {
             v,
             attention_mask,
             self.sliding_window,
+            false,
         )?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs
index 37b0aac14..89813cbde 100644
--- a/mistralrs-core/src/models/phi2.rs
+++ b/mistralrs-core/src/models/phi2.rs
@@ -194,7 +194,7 @@ impl Attention {
                 .contiguous()?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.num_heads / self.num_kv_heads)?.contiguous()?;
         let v = repeat_kv(v, self.num_heads / self.num_kv_heads)?.contiguous()?;
diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs
index e5a7307aa..fe9acd402 100644
--- a/mistralrs-core/src/models/phi3.rs
+++ b/mistralrs-core/src/models/phi3.rs
@@ -125,6 +125,7 @@ impl Attention {
             v,
             attention_mask,
             self.sliding_window,
+            true,
         )?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs
index ae3b3bede..a8b6a664c 100644
--- a/mistralrs-core/src/models/quantized_llama.rs
+++ b/mistralrs-core/src/models/quantized_llama.rs
@@ -157,7 +157,7 @@ impl LayerWeights {
                 .transpose(1, 2)?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?;
diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs
index f3e913688..8ae6fc3c9 100644
--- a/mistralrs-core/src/models/quantized_phi2.rs
+++ b/mistralrs-core/src/models/quantized_phi2.rs
@@ -80,7 +80,7 @@ impl LayerWeights {
         let q = self.forward(&q, seqlen_offsets)?.contiguous()?;
         let k = self.forward(&k, seqlen_offsets)?;
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs
index 114a99b6f..6af7b8b1b 100644
--- a/mistralrs-core/src/models/quantized_phi3.rs
+++ b/mistralrs-core/src/models/quantized_phi3.rs
@@ -97,8 +97,14 @@ impl LayerWeights {
         let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
         let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
 
-        let (k, v, attn_mask) =
-            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, Some(self.sliding_window))?;
+        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
+            kv_cache,
+            k,
+            v,
+            mask,
+            Some(self.sliding_window),
+            true,
+        )?;
 
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs
index 2c17afd6b..46037030f 100644
--- a/mistralrs-core/src/models/qwen2.rs
+++ b/mistralrs-core/src/models/qwen2.rs
@@ -158,7 +158,7 @@ impl Attention {
                 .contiguous()?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
         let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/pipeline/cache_manager.rs b/mistralrs-core/src/pipeline/cache_manager.rs
index 71590d91e..a71a4d2cc 100644
--- a/mistralrs-core/src/pipeline/cache_manager.rs
+++ b/mistralrs-core/src/pipeline/cache_manager.rs
@@ -80,13 +80,20 @@ impl Cache {
         cache: &mut Option<(Tensor, Tensor)>,
         k: Tensor,
         v: Tensor,
+        slow_cat: bool,
     ) -> Result<(Tensor, Tensor), candle_core::Error> {
         let (k, v) = match &*cache {
             None => (k, v),
             Some((k_cache, v_cache)) => {
-                let k = candle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?;
-                let v = candle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?;
-                (k, v)
+                if !slow_cat {
+                    let k = candle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?;
+                    let v = candle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?;
+                    (k, v)
+                } else {
+                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                    (k, v)
+                }
             }
         };
         *cache = Some((k.clone(), v.clone()));
@@ -100,6 +107,7 @@
         v: Tensor,
         attention_mask: Option<&Tensor>,
         sliding_window: Option<usize>,
+        slow_cat: bool,
     ) -> Result<(Tensor, Tensor, Option<Tensor>), candle_core::Error> {
         let (k, v, attention_mask) = match cache.clone() {
             None => (k, v, attention_mask.cloned()),
@@ -132,8 +140,15 @@
                         }
                     }
                 }
-                let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
-                let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+                let (k, v) = if !slow_cat {
+                    let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+                    let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+                    (k, v)
+                } else {
+                    let k = Tensor::cat(&[prev_k, k], 2)?.contiguous()?;
+                    let v = Tensor::cat(&[prev_v, v], 2)?.contiguous()?;
+                    (k, v)
+                };
                 (k, v, mask)
             }
         };
diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs
index de64b40d2..d98ec58f6 100644
--- a/mistralrs-core/src/xlora_models/gemma.rs
+++ b/mistralrs-core/src/xlora_models/gemma.rs
@@ -314,7 +314,7 @@ impl Attention {
                 .contiguous()?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
         let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs
index 6e0298ef3..8be88991e 100644
--- a/mistralrs-core/src/xlora_models/llama.rs
+++ b/mistralrs-core/src/xlora_models/llama.rs
@@ -99,7 +99,8 @@ impl CausalSelfAttention {
                 .contiguous()?;
         }
 
-        let (k, v) = crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v)?;
+        let (k, v) =
+            crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;
 
         let k = repeat_kv(k, self.num_attention_heads / self.num_key_value_heads)?.contiguous()?;
         let v = repeat_kv(v, self.num_attention_heads / self.num_key_value_heads)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs
index 9acf93643..c45923c13 100644
--- a/mistralrs-core/src/xlora_models/mistral.rs
+++ b/mistralrs-core/src/xlora_models/mistral.rs
@@ -280,6 +280,7 @@ impl Attention {
             v,
             attention_mask,
             self.sliding_window,
+            false,
         )?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs
index 195e622aa..ff7258705 100644
--- a/mistralrs-core/src/xlora_models/mixtral.rs
+++ b/mistralrs-core/src/xlora_models/mixtral.rs
@@ -180,6 +180,7 @@ impl Attention {
             v,
             attention_mask,
             self.sliding_window,
+            false,
         )?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs
index bf32bd9c7..55248e9b1 100644
--- a/mistralrs-core/src/xlora_models/phi2.rs
+++ b/mistralrs-core/src/xlora_models/phi2.rs
@@ -283,7 +283,7 @@ impl Attention {
                 .contiguous()?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.num_heads / self.num_kv_heads)?.contiguous()?;
         let v = repeat_kv(v, self.num_heads / self.num_kv_heads)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs
index b658e9c64..99f2131f5 100644
--- a/mistralrs-core/src/xlora_models/phi3.rs
+++ b/mistralrs-core/src/xlora_models/phi3.rs
@@ -144,6 +144,7 @@ impl Attention {
             v,
             attention_mask,
             self.sliding_window,
+            true,
         )?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs
index b3cd75614..683f04dc9 100644
--- a/mistralrs-core/src/xlora_models/quantized_llama.rs
+++ b/mistralrs-core/src/xlora_models/quantized_llama.rs
@@ -228,7 +228,7 @@ impl LayerWeights {
                 .transpose(1, 2)?;
         }
 
-        let (k, v) = Cache::update_kv_cache(kv_cache, k, v)?;
+        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?;
diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs
index 8b2de5365..dada5ccb2 100644
--- a/mistralrs-core/src/xlora_models/quantized_phi3.rs
+++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs
@@ -147,8 +147,14 @@ impl LayerWeights {
         let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
         let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
 
-        let (k, v, attn_mask) =
-            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, Some(self.sliding_window))?;
+        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
+            kv_cache,
+            k,
+            v,
+            mask,
+            Some(self.sliding_window),
+            true,
+        )?;
 
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
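
Reviewer note, not part of the patch: the snippet below is a minimal stand-alone sketch of the two concatenation paths the new `slow_cat` flag chooses between, assuming the candle_core/candle_nn crates that mistralrs-core already depends on. The helper name, tensor shapes, and the main function are made up for illustration only.

// Illustrative sketch: mirrors the branch the patch adds inside
// Cache::update_kv_cache. `append_to_kv_cache` is a hypothetical helper,
// not a function from mistralrs-core.
use candle_core::{DType, Device, Result, Tensor};

fn append_to_kv_cache(cache: &Tensor, new: &Tensor, slow_cat: bool) -> Result<Tensor> {
    if !slow_cat {
        // Fast path: fused KV-concat op, taken when call sites pass `false`.
        candle_nn::ops::kvconcat(cache, new, 2)
    } else {
        // "Slow" fallback: plain concatenation, taken when call sites pass `true`.
        Tensor::cat(&[cache, new], 2)?.contiguous()
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, n_kv_heads, seq_len, head_dim); both paths concatenate on dim 2.
    let cached = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?;
    let step = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
    let fast = append_to_kv_cache(&cached, &step, false)?;
    let slow = append_to_kv_cache(&cached, &step, true)?;
    assert_eq!(fast.dims(), slow.dims()); // both end up with seq_len 17
    Ok(())
}

As the hunks above show, only the Phi-3 call sites (models/phi3.rs, models/quantized_phi3.rs, and their xlora_models counterparts) pass `true` and take the plain Tensor::cat fallback; every other model passes `false` and stays on the fused kvconcat path.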