Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concat in PhiRotaryEmbedding #268

Merged
merged 1 commit into from
May 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mistralrs-core/src/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{ops::Mul, str::FromStr};

use candle_core::{quantized::QTensor, DType, Device, Result, Tensor};
use candle_core::{quantized::QTensor, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{
layer_norm::{RmsNormNonQuantized, RmsNormQuantized},
Module, VarBuilder,
Expand Down Expand Up @@ -209,11 +209,13 @@ impl PhiRotaryEmbedding {
let mut q_embeds = Vec::new();
let mut k_embeds = Vec::new();
let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
for offset in seqlen_offsets {
for (i, offset) in seqlen_offsets.iter().enumerate() {
let cos = cos.narrow(0, *offset, seq_len)?;
let sin = sin.narrow(0, *offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
let q_embed =
candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
let k_embed =
candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
q_embeds.push(q_embed);
k_embeds.push(k_embed);
}
Expand Down
Loading