diff --git a/Cargo.toml b/Cargo.toml
index 567d803ba..f10bc9a6b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,8 +20,8 @@ license = "MIT"
 
 [workspace.dependencies]
 anyhow = { version = "1.0.80", feature = "std" }
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0" }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", branch = "remove_candle_layer_norm" }
+candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", branch = "remove_candle_layer_norm" }
 serde = "1.0.197"
 serde_json = "1.0.114"
 indexmap = { version = "2.2.5", features = ["serde"] }
diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml
index 06b444b48..0a4e96ff6 100644
--- a/mistralrs-core/Cargo.toml
+++ b/mistralrs-core/Cargo.toml
@@ -17,8 +17,8 @@ candle-core.workspace = true
 candle-nn.workspace = true
 serde.workspace = true
 serde_json.workspace = true
-candle-transformers = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0" }
-candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", optional = true }
+candle-transformers = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", branch = "remove_candle_layer_norm" }
+candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", optional = true, branch = "remove_candle_layer_norm" }
 dirs = "5.0.1"
 hf-hub = "0.3.2"
 thiserror = "1.0.57"
diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index ea5dc1a9d..60bce21a0 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -14,10 +14,7 @@ use candle_core::{
     quantized::{gguf_file, QMatMul, QTensor},
     DType, Device, IndexOp, Result, Tensor, WithDType,
 };
-use candle_nn::{
-    layer_norm::{RmsNormNonQuantized, RmsNormQuantized},
-    Linear, Module, VarBuilder,
-};
+use candle_nn::{Linear, Module, VarBuilder};
 use once_cell::sync::Lazy;
 
 // (bs, tgt_len, past_kv_len)
@@ -28,7 +25,6 @@ use crate::{cublaslt::CUBLASLT_HANDLE, models::phi3, INHIBIT_GEMM_F16};
 
 #[derive(Debug, Clone)]
 pub struct RmsNorm {
-    inner: candle_nn::RmsNorm<RmsNormNonQuantized>,
     eps: f64,
     weight: Tensor,
 }
@@ -37,47 +33,37 @@ impl RmsNorm {
     pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?;
         let w = inner.inner().weight().clone();
-        Ok(Self {
-            inner,
-            eps,
-            weight: w,
-        })
+        Ok(Self { eps, weight: w })
     }
 
     pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
-        let inner = candle_nn::RmsNorm::<RmsNormNonQuantized>::new(w.clone(), eps);
-        Ok(Self {
-            inner,
-            eps,
-            weight: w,
-        })
+        Ok(Self { eps, weight: w })
     }
 }
 
 impl Module for RmsNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        if x.device().is_cpu() {
-            // Handle device mapping case
-            return candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32);
-        }
-        self.inner.forward(x)
+        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct QRmsNorm {
-    inner: candle_nn::RmsNorm<RmsNormQuantized>,
+    eps: f64,
+    weight: Tensor,
 }
 
 impl QRmsNorm {
     pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
         let scale = scale.dequantize(&scale.device())?;
-        let inner = candle_nn::RmsNorm::<RmsNormQuantized>::new(scale, eps as f64);
-        Ok(Self { inner })
+        Ok(Self {
+            eps: eps as f64,
+            weight: scale,
+        })
     }
 
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        self.inner.forward(x)
+        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
     }
 }
 