
Commit

Implement GPTQ quantization (#467)
* Add the kernels

* Remove include of aten or torch

* Add the ffi bindings

* Sketch the forward method

* Handle input and output reshapes

* Add some features

* Improve compat of build.rs

* Fix workspace dep

* Finish merge

* Fixes

* Finish gptq gemm and add trait

* Add the cuda gptq matmul stub

* Remove default feature

* Correct conditional comp

* Add gguf qmatmul quantized support

* Implement matmul with qmethod in qllama

* Update readme of mistralrs quant

* int* to int64* in q_gemm.cu

* Add model and pipeline

* Add gptq loader selector

* Rename quantized_config -> quantization_config

* Fix g_idx shape

* Ensure WNA16

* Broadcast add

* Format

* int64_t* -> int* rollback

* Prep for correct types

* Finish merge

* Complete merge

* Update cargo lock

* Integrate with new i32 type

* Fixes

* More progress

* It doesn't crash

* Oops

* It works!

* Remove some todos

* Testing isq support

* Add to all non adapter models

* Clippy

* Avoid reallocating

* Add support for gptq to adapter models

* Add docs and logging

* Update docs

* Clippy

* Remove a todo
EricLBuehler committed Aug 9, 2024
1 parent 249299b commit 1269bd8
Showing 77 changed files with 6,009 additions and 1,244 deletions.
240 changes: 132 additions & 108 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
@@ -6,6 +6,7 @@ members = [
"mistralrs",
"mistralrs-bench",
"mistralrs-vision",
"mistralrs-quant",
]
exclude = [
"mistralrs-paged_attn",
@@ -24,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
1 change: 1 addition & 0 deletions README.md
@@ -98,6 +98,7 @@ Mistal.rs supports several model categories:
- Please suggest more by raising an issue!
- Tool calling: [docs](docs/TOOL_CALLING.md)
- Prompt chunking (only without PagedAttention for now): handle larger prompts where the activation size would cause an OOM by sending chunks
- Various quantizations (GGUF, GPTQ, ISQ): [docs](docs/QUANTS.md)


This is a demo of interactive mode with streaming running Phi 3 128k mini with quantization via ISQ to Q4K.
41 changes: 41 additions & 0 deletions docs/QUANTS.md
@@ -0,0 +1,41 @@
# Quantization in mistral.rs

Mistral.rs supports the following quantization methods:
- GGUF/GGML
- Q, K type
- Supported in GGUF/GGML and GGUF/GGML adapter models
- I quants coming!
- CPU, CUDA, Metal (all supported devices)
- GPTQ
- Supported in all plain and adapter models
- CUDA only
- ISQ
- Q, K type GGUF quants
- Supported in all plain and adapter models
- I quants coming!
- GPTQ quants coming!
- CPU, CUDA, Metal (all supported devices)

## Using a GGUF quantized model
- Use the `gguf` (cli) / `GGUF` (Python) model selector
- Provide the GGUF file

```
cargo run --features cuda -- -i gguf -f my-gguf-file.gguf
```

## Using ISQ
See the [docs](ISQ.md)

```
cargo run --features cuda -- -i --isq Q4K plain -m microsoft/Phi-3-mini-4k-instruct -a phi3
```

## Using a GPTQ quantized model
- Use the `plain` (cli) / `Plain` (Python) model selector
- Provide the model ID for the GPTQ model
- Mistral.rs will automatically detect and use GPTQ quantization.

```
cargo run --features cuda -- -i plain -m kaitchup/Phi-3-mini-4k-instruct-gptq-4bit -a phi3
```
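
The automatic detection mentioned above works from the `quantization_config` entry in the model's `config.json` (note the commit bullet renaming `quantized_config` to `quantization_config`). The sketch below is illustrative only: the struct and field names are assumptions modeled on common GPTQ configs, not the actual types added in this commit.

```rust
use serde::Deserialize;

// Hypothetical mirror of the `quantization_config` section of a GPTQ model's config.json.
// Field names are illustrative, not the definitions used by mistral.rs.
#[derive(Debug, Deserialize)]
struct QuantizationConfig {
    quant_method: String, // e.g. "gptq"
    bits: usize,          // e.g. 4
    group_size: i64,      // e.g. 128, or -1 for one group per column
}

#[derive(Debug, Deserialize)]
struct ModelConfig {
    #[serde(default)]
    quantization_config: Option<QuantizationConfig>,
}

// Returns true when the config declares GPTQ weights, so a loader can select the GPTQ path.
fn is_gptq_model(config_json: &str) -> serde_json::Result<bool> {
    let cfg: ModelConfig = serde_json::from_str(config_json)?;
    Ok(cfg
        .quantization_config
        .as_ref()
        .is_some_and(|q| q.quant_method == "gptq"))
}
```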
5 changes: 3 additions & 2 deletions mistralrs-core/Cargo.toml
@@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
@@ -71,14 +71,15 @@ base64.workspace = true
bytemuck_derive = "1.7.0"
plotly = { version = "0.9.0", features = ["kaleido"], optional = true }
mistralrs-paged-attn = { version = "0.2.4", path = "../mistralrs-paged-attn", optional = true }
mistralrs-quant = { version = "0.2.0", path = "../mistralrs-quant" }
uuid = { version = "1.10.0", features = ["v4"] }
schemars = "0.8.21"

[features]
default = ["plotly"]
plotly = ["dep:plotly"]
pyo3_macros = ["pyo3"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"]
cudnn = ["candle-core/cudnn"]
metal = ["candle-core/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
27 changes: 27 additions & 0 deletions mistralrs-core/src/cuda/ffi.rs
@@ -9,6 +9,7 @@ extern "C" {
pub(crate) fn count_nonzero_u8(d_in: *const c_void, N: u32) -> u32;
pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32) -> u32;
pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32) -> u32;
pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32) -> u32;
pub(crate) fn nonzero_bf16(
d_in: *const c_void,
N: u32,
@@ -65,6 +66,14 @@ extern "C" {
num_dims: u32,
d_out: *mut c_void,
);
pub(crate) fn nonzero_i32(
d_in: *const c_void,
N: u32,
num_nonzero: u32,
dims: *const c_void,
num_dims: u32,
d_out: *mut c_void,
);

pub(crate) fn bitwise_and_u8(
d_in1: *const c_void,
@@ -84,6 +93,12 @@ extern "C" {
d_out: *mut c_void,
N: u32,
);
pub(crate) fn bitwise_and_i32(
d_in1: *const c_void,
d_in2: *const c_void,
d_out: *mut c_void,
N: u32,
);
pub(crate) fn bitwise_or_u8(
d_in1: *const c_void,
d_in2: *const c_void,
@@ -102,6 +117,12 @@ extern "C" {
d_out: *mut c_void,
N: u32,
);
pub(crate) fn bitwise_or_i32(
d_in1: *const c_void,
d_in2: *const c_void,
d_out: *mut c_void,
N: u32,
);
pub(crate) fn bitwise_xor_u8(
d_in1: *const c_void,
d_in2: *const c_void,
@@ -120,4 +141,10 @@ extern "C" {
d_out: *mut c_void,
N: u32,
);
pub(crate) fn bitwise_xor_i32(
d_in1: *const c_void,
d_in2: *const c_void,
d_out: *mut c_void,
N: u32,
);
}
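
A usage sketch for the new `i32` bindings above (illustrative pseudocode for the calling convention, not code from this commit): the nonzero kernels are used in two passes, counting first so the caller can size the index buffer, then filling it. The helper name and the allocation closure are hypothetical, and the raw pointers are assumed to be valid CUDA device allocations obtained from the backend.

```rust
use std::ffi::c_void;

// Assumes the extern "C" declarations above (count_nonzero_i32, nonzero_i32) are in scope.
//
// Safety: all pointers must reference device buffers of the stated sizes; `alloc_out`
// must return a device buffer large enough for `count * num_dims` indices.
unsafe fn nonzero_indices_i32(
    d_in: *const c_void,  // device buffer of `n` i32 values
    n: u32,
    dims: *const c_void,  // device buffer holding the tensor's dimension sizes
    num_dims: u32,
    alloc_out: impl FnOnce(u32) -> *mut c_void,
) -> u32 {
    // Pass 1: count the nonzero elements.
    let count = count_nonzero_i32(d_in, n);
    if count > 0 {
        // Pass 2: write the indices of the nonzero elements into the freshly sized buffer.
        let d_out = alloc_out(count);
        nonzero_i32(d_in, n, count, dims, num_dims, d_out);
    }
    count
}
```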
3 changes: 3 additions & 0 deletions mistralrs-core/src/cuda/nonzero_bitwise.cu
@@ -56,6 +56,7 @@ COUNT_NONZERO_OP(double, f64)
COUNT_NONZERO_OP(uint8_t, u8)
COUNT_NONZERO_OP(uint32_t, u32)
COUNT_NONZERO_OP(int64_t, i64)
COUNT_NONZERO_OP(int32_t, i32)

__global__ void transform_indices(const uint32_t *temp_indices,
const uint32_t num_nonzero,
@@ -126,6 +127,7 @@ NONZERO_OP(double, f64)
NONZERO_OP(uint8_t, u8)
NONZERO_OP(uint32_t, u32)
NONZERO_OP(int64_t, i64)
NONZERO_OP(int32_t, i32)

template <typename T>
__global__ void bitwise_and__kernel(const T *d_in1, const T *d_in2, T *d_out,
@@ -207,3 +209,4 @@ void bitwise_xor(const T *d_in1, const T *d_in2, T *d_out, int N) {
BITWISE_OP(uint8_t, u8)
BITWISE_OP(uint32_t, u32)
BITWISE_OP(int64_t, i64)
BITWISE_OP(int32_t, i32)
10 changes: 10 additions & 0 deletions mistralrs-core/src/layers.rs
@@ -16,6 +16,7 @@ use candle_core::{
DType, Device, IndexOp, Result, Shape, Tensor, D,
};
use candle_nn::{Linear, Module, VarBuilder};
use mistralrs_quant::QuantMethod;
use serde::Deserialize;

pub use crate::layers_masker::CausalMasker;
@@ -429,6 +430,15 @@ impl MatMul {
matmul.forward(x)
}
}

/// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
if get_use_matmul_via_f16() {
matmul.forward_via_half(x)
} else {
matmul.forward(x)
}
}
}

/// Computes softmax(QK^T*sqrt(d_k))V
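
Per the commit message ("Implement matmul with qmethod in qllama"), the new `qmethod_matmul` helper is what the quantized llama implementation calls into. A minimal usage sketch follows; the struct, field names, and module path are hypothetical, and only `MatMul::qmethod_matmul` and the `QuantMethod` trait come from this commit.

```rust
use std::sync::Arc;

use candle_core::{Result, Tensor};
use mistralrs_quant::QuantMethod;

use crate::layers::MatMul; // path assumes the sketch lives inside mistralrs-core

// Hypothetical attention fragment holding one quantized projection.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
}

impl Attention {
    fn project_q(&self, mm: &MatMul, xs: &Tensor) -> Result<Tensor> {
        // Dispatches to forward_via_half() when f16 GEMM is enabled, otherwise to forward().
        mm.qmethod_matmul(xs, &*self.q_proj)
    }
}
```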
61 changes: 26 additions & 35 deletions mistralrs-core/src/lora/loralinear.rs
@@ -1,23 +1,17 @@
use std::{collections::HashMap, iter::zip, ops::Mul};
use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};

use candle_core::{
bail,
quantized::{QMatMul, QTensor},
Module, Result, Tensor,
};
use candle_core::{bail, DType, Module, Result, Tensor};
use candle_nn::{Linear, VarBuilder};
use either::Either;

use crate::layers::QLinear;
use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear};

use super::{
apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper,
LinearLayerLike, LoraConfig, LoraLinearConfig, Merge,
};

#[derive(Debug)]
pub struct LoraLinear {
old: QLinear,
old: Arc<dyn QuantMethod>,
a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
scale_adapters: Vec<f64>,
@@ -106,7 +100,9 @@ impl LoraLinear {
.to_dtype(a_adapters_stack.dtype())?;
let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
Ok(LoraLinear {
old: QLinear::from_parts(old.weight().clone(), old.bias().cloned()),
old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
Linear::new(old.weight().clone(), old.bias().cloned()),
))?),
a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
b_adapters: Either::Right((b_adapters_stack, b_adapters)),
scale_adapters,
@@ -116,7 +112,9 @@ impl LoraLinear {
})
} else {
Ok(LoraLinear {
old: QLinear::from_parts(old.weight().clone(), old.bias().cloned()),
old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
Linear::new(old.weight().clone(), old.bias().cloned()),
))?),
a_adapters: Either::Left(a_adapters),
b_adapters: Either::Left(b_adapters),
scale_adapters,
@@ -176,43 +174,36 @@ impl Merge for LoraLinear {
}

fn merge_weights(&mut self) -> Result<()> {
match &self.old.inner() {
QMatMul::QTensor(q) => {
let (mut w_base_layer, dtype) = (q.dequantize(&q.device())?, q.dtype());
for adapter in 0..self.scale_adapters.len() {
w_base_layer = (w_base_layer + self.get_delta_weight(adapter))?;
}
let new_w = QTensor::quantize(&w_base_layer, dtype)?;
self.old = QLinear::from_qparts(new_w, self.old.bias().cloned());
}
QMatMul::Tensor(w_base_layer) | QMatMul::TensorF16(w_base_layer) => {
let mut w_base_layer = w_base_layer.clone();
for adapter in 0..self.scale_adapters.len() {
w_base_layer = (w_base_layer + self.get_delta_weight(adapter))?;
}
self.old = QLinear::from_parts(w_base_layer, self.old.bias().cloned());
let mut w_base_layer: Option<Tensor> = None;
for adapter in 0..self.scale_adapters.len() {
if let Some(w_base_layer) = &mut w_base_layer {
*w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
} else {
w_base_layer = Some(self.get_delta_weight(adapter)?)
}
};
}
self.old
.add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
self.merged = true;
Ok(())
}
}

impl LinearLayerLike for LoraLinear {
fn inner(&mut self) -> Option<&mut candle_core::quantized::QMatMul> {
Arc::get_mut(&mut self.old).unwrap().get_qmatmul()
}
fn bias(&self) -> Option<&Tensor> {
self.old.bias()
unreachable!()
}
fn bias_mut(&mut self) -> Option<&mut Tensor> {
self.old.bias_mut()
unreachable!()
}
fn weight(&self) -> &Tensor {
unreachable!()
}
fn inner(&mut self) -> &mut QMatMul {
self.old.inner()
}
fn is_quant(&self) -> bool {
self.old.is_quant()
fn quantized_act_type(&self) -> Option<DType> {
self.old.quantized_act_type()
}
fn lora_forward(
&self,
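
For reference, the rewritten `merge_weights` accumulates one delta per adapter via `get_delta_weight` and then folds the sum into the base weight with `add_delta_w`. A sketch of the underlying LoRA merge identity, with s_i the per-adapter scale (assumed from standard LoRA, not quoted from this commit):

```latex
W_{\text{merged}} = W + \sum_i \Delta W_i, \qquad \Delta W_i = s_i \, B_i A_i
```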
18 changes: 9 additions & 9 deletions mistralrs-core/src/lora/mod.rs
@@ -4,7 +4,7 @@ use std::{collections::HashSet, fmt::Debug, sync::Arc};

use candle_core::{
quantized::{QMatMul, QTensor},
IndexOp, Result, Tensor, D,
DType, IndexOp, Result, Tensor, D,
};
use candle_nn::{init, Linear, Module, VarBuilder};
use loralinear::LoraLinear;
@@ -97,9 +97,9 @@ fn make_adapter(
}

/// Any layer that is linear-like.
pub trait LinearLayerLike: Debug + Merge + AdapterSwapper {
fn inner(&mut self) -> &mut QMatMul;
fn is_quant(&self) -> bool;
pub trait LinearLayerLike: Merge + AdapterSwapper {
fn quantized_act_type(&self) -> Option<DType>;
fn inner(&mut self) -> Option<&mut QMatMul>;
fn is_lora(&self) -> bool;
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
@@ -152,12 +152,12 @@ impl AdapterSwapper for Linear {
}

impl LinearLayerLike for Linear {
fn inner(&mut self) -> &mut QMatMul {
unreachable!()
}
fn bias(&self) -> Option<&Tensor> {
self.bias()
}
fn inner(&mut self) -> Option<&mut QMatMul> {
None
}
fn bias_mut(&mut self) -> Option<&mut Tensor> {
unreachable!()
}
@@ -173,8 +173,8 @@
) -> Result<Tensor> {
self.forward(x)
}
fn is_quant(&self) -> bool {
false
fn quantized_act_type(&self) -> Option<DType> {
None
}
fn is_lora(&self) -> bool {
false
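
The `is_quant()` to `quantized_act_type()` change lets callers cast activations into the dtype a quantized kernel expects rather than branching on a bare boolean; `QLoraLinear` below reports `Some(DType::F32)`, for example. A hedged sketch of the caller-side pattern (the helper name is illustrative, not from this commit):

```rust
use candle_core::{DType, Result, Tensor};

// Illustrative helper: cast the input into the layer's quantized activation dtype (if any),
// run the forward closure, then cast the result back to the original dtype.
fn with_quantized_act_type(
    act_type: Option<DType>,
    x: &Tensor,
    forward: impl FnOnce(&Tensor) -> Result<Tensor>,
) -> Result<Tensor> {
    match act_type {
        Some(t) => {
            let original = x.dtype();
            forward(&x.to_dtype(t)?)?.to_dtype(original)
        }
        None => forward(x),
    }
}
```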
12 changes: 6 additions & 6 deletions mistralrs-core/src/lora/qloralinear.rs
@@ -3,7 +3,7 @@ use std::{collections::HashMap, iter::zip, ops::Mul};
use candle_core::{
bail,
quantized::{QMatMul, QTensor},
Module, Result, Tensor,
DType, Module, Result, Tensor,
};
use candle_nn::{Linear, VarBuilder};
use either::Either;
@@ -229,6 +229,9 @@ impl Merge for QLoraLinear {
}

impl LinearLayerLike for QLoraLinear {
fn inner(&mut self) -> Option<&mut candle_core::quantized::QMatMul> {
Some(&mut self.old)
}
fn bias(&self) -> Option<&Tensor> {
None
}
@@ -238,11 +241,8 @@ impl LinearLayerLike for QLoraLinear {
fn weight(&self) -> &Tensor {
unimplemented!()
}
fn inner(&mut self) -> &mut QMatMul {
&mut self.old
}
fn is_quant(&self) -> bool {
true
fn quantized_act_type(&self) -> Option<DType> {
Some(DType::F32)
}
fn lora_forward(
&self,
(Diffs for the remaining changed files are not shown.)
