Support rope scaling for phi3 models (Phi3 128k)
guoqingbao committed Jul 5, 2024
1 parent 7cffa66 commit d763d4c
Showing 6 changed files with 130 additions and 8 deletions.
7 changes: 5 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -35,7 +35,7 @@ env_logger = "0.10.1"
tracing = "0.1.40"
range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" }
chrono = { version = "0.4.31", features = ["clock"] }
-either = "1.9.0"
+either = { version = "1.13.0", features = ["serde"] }
dirs = "5.0.1"
kernels = {path = "./kernels", version="0.1.0"}

2 changes: 2 additions & 0 deletions src/openai/models/llama.rs
@@ -44,6 +44,8 @@ impl LlamaConfig {
sliding_window: None,
hidden_act: None,
tie_word_embeddings: false,
rope_scaling: None,
original_max_position_embeddings: None,
}
}
}
8 changes: 8 additions & 0 deletions src/openai/models/mod.rs
@@ -1,6 +1,12 @@
pub mod llama;
pub mod phi3;
pub mod qwen2;
use either::Either;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Deserialize, Debug, Clone)]
pub struct RopeScaling(#[serde(with = "either::serde_untagged")] pub Either<Vec<f64>, String>);

#[derive(Debug, Clone)]
pub struct Config {
@@ -19,6 +25,8 @@ pub struct Config {
pub sliding_window: Option<usize>,
pub hidden_act: Option<candle_nn::Activation>,
pub tie_word_embeddings: bool,
pub rope_scaling: Option<HashMap<String, RopeScaling>>,
pub original_max_position_embeddings: Option<usize>,
}

impl Config {
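
For reference, here is a minimal standalone sketch (not part of the commit) of how the new RopeScaling wrapper deserializes a Phi-3-style rope_scaling entry. The untagged Either lets the numeric short_factor/long_factor lists and the type string share a single map value type, which is also why the serde feature of the either crate is enabled in Cargo.toml above. The JSON values are illustrative only, and serde_json is assumed to be available.

// Standalone sketch; mirrors RopeScaling from src/openai/models/mod.rs.
use either::Either;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Deserialize, Debug, Clone)]
pub struct RopeScaling(#[serde(with = "either::serde_untagged")] pub Either<Vec<f64>, String>);

fn main() -> serde_json::Result<()> {
    // Illustrative values only; real Phi-3 configs carry one factor per rotary dimension pair.
    let raw = r#"{
        "type": "longrope",
        "short_factor": [1.0, 1.05, 1.1],
        "long_factor": [2.0, 2.5, 3.0]
    }"#;
    let scaling: HashMap<String, RopeScaling> = serde_json::from_str(raw)?;
    match &scaling["type"] {
        // Lists deserialize into the Left variant, strings into the Right variant.
        RopeScaling(Either::Right(tp)) => println!("scaling type: {tp}"),
        _ => println!("unexpected shape for `type`"),
    }
    Ok(())
}
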
117 changes: 112 additions & 5 deletions src/openai/models/phi3.rs
@@ -1,16 +1,17 @@
// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
-use super::Config;
+use super::{Config, RopeScaling};
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_core as candle;
use candle_nn::VarBuilder;
use candle_transformers::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use either::Either;
use std::collections::HashMap;
use std::iter::zip;
use std::sync::Arc;

// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize)]
pub struct PhiConfig {
pub vocab_size: usize,
@@ -24,8 +25,9 @@ pub struct PhiConfig {
pub rope_theta: f64,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
-pub rope_scaling: Option<String>,
+pub rope_scaling: Option<HashMap<String, RopeScaling>>,
pub max_position_embeddings: usize,
pub original_max_position_embeddings: Option<usize>,
pub sliding_window: Option<usize>,
}

@@ -47,6 +49,8 @@ impl PhiConfig {
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
tie_word_embeddings: false,
rope_scaling: self.rope_scaling,
original_max_position_embeddings: self.original_max_position_embeddings,
}
}
}
@@ -55,6 +59,9 @@ impl PhiConfig {
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
sin_long: Option<Tensor>,
cos_long: Option<Tensor>,
original_max_position_embeddings: Option<usize>,
}

impl RotaryEmbedding {
@@ -71,9 +78,89 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;

if let Some(rope_scaling) = &cfg.rope_scaling {
match (
&rope_scaling["short_factor"],
&rope_scaling["long_factor"],
&rope_scaling["type"],
) {
(
RopeScaling(Either::Left(short_factor)),
RopeScaling(Either::Left(long_factor)),
RopeScaling(Either::Right(tp)),
) => {
let scale = cfg.max_seq_len as f64
/ cfg.original_max_position_embeddings.unwrap() as f64;
let scaling_factor = if scale <= 1.0 {
1.0
} else {
match tp.as_str() {
"su" | "longrope" => (1.0
+ scale.ln()
/ (cfg.original_max_position_embeddings.unwrap() as f64).ln())
.sqrt(),
"yarn" => 0.1 * scale.ln() + 1.0,
_ => 1.0,
}
};
// Calculate inv freqs for short, long
let inv_freq_long = (0..dim)
.step_by(2)
.enumerate()
.map(|(k, i)| {
(1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)))
as f32
})
.collect::<Vec<_>>();
let inv_freq_short = (0..dim)
.step_by(2)
.enumerate()
.map(|(k, i)| {
(1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)))
as f32
})
.collect::<Vec<_>>();
let inv_freq_len = inv_freq_long.len();

let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;

// Calculate sin,cos for long
let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?
.to_dtype(DType::F32)?;
let freqs_long = t.matmul(&inv_freq_long)?;
let long_sin = (freqs_long.sin()? * scaling_factor)?;
let long_cos = (freqs_long.cos()? * scaling_factor)?;

// Calculate sin,cos for short
let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?
.to_dtype(DType::F32)?;
let freqs_short = t.matmul(&inv_freq_short)?;
let short_sin = (freqs_short.sin()? * scaling_factor)?;
let short_cos = (freqs_short.cos()? * scaling_factor)?;

return Ok(Self {
sin: short_sin,
cos: short_cos,
sin_long: Some(long_sin),
cos_long: Some(long_cos),
original_max_position_embeddings: cfg.original_max_position_embeddings,
});
}
_ => {
panic!("Unknown config for rope scaling!")
}
}
}

Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
sin_long: None,
cos_long: None,
original_max_position_embeddings: None,
})
}

@@ -84,8 +171,28 @@ impl RotaryEmbedding {
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
-let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
-let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;

let (cos, sin) = if self.sin_long.as_ref().is_some()
&& self.cos_long.as_ref().is_some()
&& self.original_max_position_embeddings.is_some()
&& seqlen_offset > self.original_max_position_embeddings.unwrap()
{
let cos = self
.cos_long
.as_ref()
.unwrap()
.narrow(0, seqlen_offset, seq_len)?;
let sin = self
.sin_long
.as_ref()
.unwrap()
.narrow(0, seqlen_offset, seq_len)?;
(cos, sin)
} else {
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
(cos, sin)
};
let q_embed = candle_nn::rotary_emb::rope(&q, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k, &cos, &sin)?;
Ok((q_embed, k_embed))
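
To make the new rotary-embedding math easier to follow, here is a minimal standalone sketch (not part of the commit) of the sin/cos magnitude scaling factor computed in RotaryEmbedding::new above. The function name and the 131072/4096 values are illustrative only; they correspond to the usual Phi-3-mini-128k settings for max_position_embeddings and original_max_position_embeddings.

// Standalone sketch of the attention-scaling factor applied to both the short and
// long sin/cos tables; mirrors the match on rope_scaling["type"] in phi3.rs above.
fn scaling_factor(max_seq_len: usize, original_max: usize, scaling_type: &str) -> f64 {
    let scale = max_seq_len as f64 / original_max as f64;
    if scale <= 1.0 {
        return 1.0; // no context extension, no rescaling
    }
    match scaling_type {
        // "su" / "longrope": sqrt(1 + ln(scale) / ln(original_max))
        "su" | "longrope" => (1.0 + scale.ln() / (original_max as f64).ln()).sqrt(),
        // "yarn": 0.1 * ln(scale) + 1
        "yarn" => 0.1 * scale.ln() + 1.0,
        _ => 1.0,
    }
}

fn main() {
    // Phi-3-mini-128k style numbers: 131072 / 4096 = 32x extension,
    // giving sqrt(1 + ln(32) / ln(4096)) ≈ 1.19.
    println!("{:.4}", scaling_factor(131_072, 4_096, "longrope"));
}

At inference time, apply_rotary_emb then switches from the short to the long sin/cos tables once seqlen_offset exceeds original_max_position_embeddings, as shown in the hunk above.
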
2 changes: 2 additions & 0 deletions src/openai/models/qwen2.rs
@@ -46,6 +46,8 @@ impl QwenConfig {
sliding_window: Some(self.sliding_window),
hidden_act: Some(self.hidden_act),
tie_word_embeddings: self.tie_word_embeddings,
rope_scaling: None,
original_max_position_embeddings: None,
}
}
}
