From 1269bd8ab6246e4b59011a328e8d32467703428e Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:49:44 -0400 Subject: [PATCH] Implement GPTQ quantization (#467) * Add the kernels * Remove include of aten or torch * Add the ffi bindings * Sketch the forward method * Handle input and output reshapes * Add some features * Improve compat of build.rs * Fix workspace dep * Finish merge * Fixes * Finish gptq gemm and add trait * Add the cuda gptq matmul stub * Remove default feature * Correct conditional comp * Add gguf qmatmul quantized support * Implement matmul with qmethod in qllama * Update readme of mistralrs quant * int* to int64* in q_gemm.cu * Add model and pipeline * Add gptq loader selector * Rename quantized_config -> quantization_config * Fix g_idx shape * Ensure WNA16 * Broadcast add * Format * int64_t* -> int* rollback * Prep for correct types * Finish merge * Complete merge * Update cargo lock * Integrate with new i32 type * Fixes * More progress * It doesnt crash * Oops * It works! * Remove some todos * Testing isq support * Add to all non adapter models * Clippy * Avoid reallocating * Add support for gptq to adapter models * Add docs and logging * Update docs * Clippy * Remove a todo --- Cargo.lock | 240 ++- Cargo.toml | 5 +- README.md | 1 + docs/QUANTS.md | 41 + mistralrs-core/Cargo.toml | 5 +- mistralrs-core/src/cuda/ffi.rs | 27 + mistralrs-core/src/cuda/nonzero_bitwise.cu | 3 + mistralrs-core/src/layers.rs | 10 + mistralrs-core/src/lora/loralinear.rs | 61 +- mistralrs-core/src/lora/mod.rs | 18 +- mistralrs-core/src/lora/qloralinear.rs | 12 +- mistralrs-core/src/model_loader.rs | 6 +- mistralrs-core/src/models/gemma.rs | 270 ++- mistralrs-core/src/models/gemma2.rs | 280 ++- mistralrs-core/src/models/llama.rs | 189 +- mistralrs-core/src/models/mistral.rs | 222 ++- mistralrs-core/src/models/mixtral.rs | 219 +- mistralrs-core/src/models/phi2.rs | 241 ++- mistralrs-core/src/models/phi3.rs | 135 +- mistralrs-core/src/models/quantized_llama.rs | 176 +- mistralrs-core/src/models/quantized_phi2.rs | 54 +- mistralrs-core/src/models/quantized_phi3.rs | 71 +- .../src/models/quantized_starcoder2.rs | 83 +- mistralrs-core/src/models/qwen2.rs | 255 ++- mistralrs-core/src/models/starcoder2.rs | 240 ++- mistralrs-core/src/ops.rs | 45 + mistralrs-core/src/pipeline/macros.rs | 2 +- mistralrs-core/src/pipeline/mod.rs | 2 + mistralrs-core/src/pipeline/normal.rs | 41 +- mistralrs-core/src/pipeline/normal_loaders.rs | 25 +- mistralrs-core/src/pipeline/paths.rs | 13 +- mistralrs-core/src/pipeline/vision.rs | 33 +- mistralrs-core/src/toml_selector.rs | 6 +- mistralrs-core/src/utils/model_config.rs | 2 +- mistralrs-core/src/utils/varbuilder_utils.rs | 46 +- mistralrs-core/src/vision_models/idefics2.rs | 1 + .../src/vision_models/llava/config.rs | 2 + mistralrs-core/src/xlora_models/gemma.rs | 96 +- mistralrs-core/src/xlora_models/gemma2.rs | 112 +- mistralrs-core/src/xlora_models/llama.rs | 90 +- mistralrs-core/src/xlora_models/mistral.rs | 96 +- mistralrs-core/src/xlora_models/mixtral.rs | 107 +- mistralrs-core/src/xlora_models/phi2.rs | 79 +- mistralrs-core/src/xlora_models/phi3.rs | 75 +- mistralrs-core/src/xlora_models/starcoder2.rs | 84 +- mistralrs-paged-attn/src/backend/mod.rs | 1 + mistralrs-pyo3/Cargo_template.toml | 2 +- mistralrs-pyo3/src/lib.rs | 9 +- mistralrs-quant/Cargo.toml | 25 + mistralrs-quant/README.md | 10 + mistralrs-quant/build.rs | 54 + mistralrs-quant/kernels/gptq/compat.cuh | 60 + 
mistralrs-quant/kernels/gptq/matrix_view.cuh | 290 +++ mistralrs-quant/kernels/gptq/q_gemm.cu | 1761 +++++++++++++++++ mistralrs-quant/kernels/gptq/qdq_2.cuh | 70 + mistralrs-quant/kernels/gptq/qdq_3.cuh | 144 ++ mistralrs-quant/kernels/gptq/qdq_4.cuh | 122 ++ mistralrs-quant/kernels/gptq/qdq_8.cuh | 24 + mistralrs-quant/kernels/gptq/qdq_util.cuh | 51 + mistralrs-quant/src/gguf/mod.rs | 105 + mistralrs-quant/src/gptq/ffi.rs | 56 + mistralrs-quant/src/gptq/gptq_cpu.rs | 55 + mistralrs-quant/src/gptq/gptq_cuda.rs | 305 +++ mistralrs-quant/src/gptq/mod.rs | 11 + mistralrs-quant/src/lib.rs | 189 ++ mistralrs-quant/src/unquantized/mod.rs | 66 + mistralrs/examples/anymoe/main.rs | 2 +- mistralrs/examples/anymoe_lora/main.rs | 2 +- mistralrs/examples/gemma2/main.rs | 2 +- mistralrs/examples/grammar/main.rs | 2 +- mistralrs/examples/isq/main.rs | 2 +- mistralrs/examples/lora/main.rs | 2 +- mistralrs/examples/lora_activation/main.rs | 2 +- mistralrs/examples/paged_attn/main.rs | 2 +- mistralrs/examples/simple/main.rs | 2 +- mistralrs/examples/tools/main.rs | 2 +- mistralrs/examples/xlora/main.rs | 2 +- 77 files changed, 6009 insertions(+), 1244 deletions(-) create mode 100644 docs/QUANTS.md create mode 100644 mistralrs-quant/Cargo.toml create mode 100644 mistralrs-quant/README.md create mode 100644 mistralrs-quant/build.rs create mode 100644 mistralrs-quant/kernels/gptq/compat.cuh create mode 100644 mistralrs-quant/kernels/gptq/matrix_view.cuh create mode 100644 mistralrs-quant/kernels/gptq/q_gemm.cu create mode 100644 mistralrs-quant/kernels/gptq/qdq_2.cuh create mode 100644 mistralrs-quant/kernels/gptq/qdq_3.cuh create mode 100644 mistralrs-quant/kernels/gptq/qdq_4.cuh create mode 100644 mistralrs-quant/kernels/gptq/qdq_8.cuh create mode 100644 mistralrs-quant/kernels/gptq/qdq_util.cuh create mode 100644 mistralrs-quant/src/gguf/mod.rs create mode 100644 mistralrs-quant/src/gptq/ffi.rs create mode 100644 mistralrs-quant/src/gptq/gptq_cpu.rs create mode 100644 mistralrs-quant/src/gptq/gptq_cuda.rs create mode 100644 mistralrs-quant/src/gptq/mod.rs create mode 100644 mistralrs-quant/src/lib.rs create mode 100644 mistralrs-quant/src/unquantized/mod.rs diff --git a/Cargo.lock b/Cargo.lock index b81c94582..30bc422a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -408,9 +408,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" dependencies = [ "bytemuck_derive", ] @@ -440,9 +440,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.1" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "bzip2" @@ -468,7 +468,7 @@ dependencies = [ [[package]] name = "candle-core" version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e09d7f3#9e09d7f3ef36bb5b30b6aa963992392271d27f49" +source = "git+https://github.com/EricLBuehler/candle.git?rev=57c5599d#57c5599d357f6ef5d34462b51645c3185d0b9b7d" dependencies = [ "accelerate-src", "byteorder", @@ -495,7 +495,7 @@ dependencies = [ [[package]] name = 
"candle-flash-attn" version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e09d7f3#9e09d7f3ef36bb5b30b6aa963992392271d27f49" +source = "git+https://github.com/EricLBuehler/candle.git?rev=57c5599d#57c5599d357f6ef5d34462b51645c3185d0b9b7d" dependencies = [ "anyhow", "bindgen_cuda 0.1.5", @@ -506,7 +506,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e09d7f3#9e09d7f3ef36bb5b30b6aa963992392271d27f49" +source = "git+https://github.com/EricLBuehler/candle.git?rev=57c5599d#57c5599d357f6ef5d34462b51645c3185d0b9b7d" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -514,7 +514,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e09d7f3#9e09d7f3ef36bb5b30b6aa963992392271d27f49" +source = "git+https://github.com/EricLBuehler/candle.git?rev=57c5599d#57c5599d357f6ef5d34462b51645c3185d0b9b7d" dependencies = [ "metal", "once_cell", @@ -525,7 +525,7 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e09d7f3#9e09d7f3ef36bb5b30b6aa963992392271d27f49" +source = "git+https://github.com/EricLBuehler/candle.git?rev=57c5599d#57c5599d357f6ef5d34462b51645c3185d0b9b7d" dependencies = [ "accelerate-src", "candle-core", @@ -542,9 +542,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.6" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +checksum = "504bdec147f2cc13c8b57ed9401fd8a147cc66b67ad5cb241394244f2c947549" dependencies = [ "jobserver", "libc", @@ -568,7 +568,7 @@ version = "0.13.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6026d8cd82ada8bbcfe337805dd1eb6afdc9e80fa4d57e977b3a36315e0c5525" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.3.0", "lazy_static", "num-traits", "regex", @@ -625,9 +625,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.11" +version = "4.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" +checksum = "c937d4061031a6d0c8da4b9a4f98a172fc2976dfb1c19213a9cf7d0d3c837e36" dependencies = [ "clap_builder", "clap_derive", @@ -635,9 +635,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.11" +version = "4.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" +checksum = "85379ba512b21a328adf887e85f7742d12e96eb31f3ef077df4ffc26b506ffed" dependencies = [ "anstream", "anstyle", @@ -647,9 +647,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.11" +version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" +checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -1094,9 +1094,9 @@ dependencies = [ [[package]] name = "dunce" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" @@ -1246,9 +1246,9 @@ dependencies = 
[ [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" dependencies = [ "crc32fast", "miniz_oxide", @@ -1604,7 +1604,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.2.6", + "indexmap 2.3.0", "slab", "tokio", "tokio-util", @@ -1803,9 +1803,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1909,9 +1909,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -2240,9 +2240,9 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" +checksum = "f4bf71af278c578cbcc91d0b1ff092910208bc86f7b3750364642bd424e3dcd3" dependencies = [ "serde", "serde_json", @@ -2250,9 +2250,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" +checksum = "b22bcb360a130647646fdda70742bff3cea5d4fe562d8a943f26675c84206088" dependencies = [ "minijinja", "serde", @@ -2306,7 +2306,7 @@ dependencies = [ "candle-core", "either", "image", - "indexmap 2.2.6", + "indexmap 2.3.0", "mistralrs-core", "serde", "serde_json", @@ -2357,7 +2357,7 @@ dependencies = [ "half", "hf-hub", "image", - "indexmap 2.2.6", + "indexmap 2.3.0", "indicatif", "intel-mkl-src", "itertools 0.13.0", @@ -2365,6 +2365,7 @@ dependencies = [ "minijinja", "minijinja-contrib", "mistralrs-paged-attn", + "mistralrs-quant", "mistralrs-vision", "once_cell", "plotly", @@ -2375,7 +2376,7 @@ dependencies = [ "rayon", "regex-automata 0.4.7", "reqwest", - "rustc-hash 2.0.0", + "rustc-hash", "schemars", "serde", "serde_json", @@ -2414,7 +2415,7 @@ dependencies = [ "either", "futures", "image", - "indexmap 2.2.6", + "indexmap 2.3.0", "intel-mkl-src", "mistralrs-core", "pyo3", @@ -2425,6 +2426,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "mistralrs-quant" +version = "0.2.4" +dependencies = [ + "bindgen_cuda 0.1.5", + "candle-core", + "candle-nn", + "half", + "lazy_static", + "serde", +] + [[package]] name = "mistralrs-server" version = "0.2.4" @@ -2439,7 +2452,7 @@ dependencies = [ "either", "futures", "image", - "indexmap 2.2.6", + "indexmap 2.3.0", "intel-mkl-src", "mistralrs-core", "once_cell", @@ -2617,18 +2630,18 @@ dependencies = [ [[package]] name = "num_enum" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02339744ee7253741199f897151b38e72257d13802d4ee837285cc2990a90845" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" 
dependencies = [ "num_enum_derive", ] [[package]] name = "num_enum_derive" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "681030a937600a36906c185595136d26abfebb4aa9c65701cefcaf8578bb982b" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2663,9 +2676,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.2" +version = "0.36.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" dependencies = [ "memchr", ] @@ -2980,7 +2993,7 @@ dependencies = [ "dunce", "serde", "serde_json", - "zip 2.1.5", + "zip 2.1.6", ] [[package]] @@ -3010,9 +3023,12 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro-crate" @@ -3081,7 +3097,7 @@ dependencies = [ "either", "eyre", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.3.0", "indoc", "libc", "memoffset", @@ -3161,16 +3177,17 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quinn" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" +checksum = "b22d8e7369034b9a7132bc2008cac12f2013c8132b45e0554e6e20e2617f2156" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 1.1.0", + "rustc-hash", "rustls", + "socket2", "thiserror", "tokio", "tracing", @@ -3178,14 +3195,14 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.3" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" +checksum = "ba92fb39ec7ad06ca2582c0ca834dfeadcaf06ddfc8e635c80aa7e1c05315fdd" dependencies = [ "bytes", "rand", "ring", - "rustc-hash 1.1.0", + "rustc-hash", "rustls", "slab", "thiserror", @@ -3202,6 +3219,7 @@ dependencies = [ "libc", "once_cell", "socket2", + "tracing", "windows-sys 0.52.0", ] @@ -3350,9 +3368,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -3506,12 +3524,6 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.0.0" @@ -3557,9 +3569,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -3567,9 +3579,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" @@ -3596,9 +3608,9 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "safetensors" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ced76b22c7fba1162f11a5a75d9d8405264b467a07ae0c9c29be119b9297db9" +checksum = "7725d4d98fa515472f43a6e2bbf956c48e06b89bb50593a040e5945160214450" dependencies = [ "serde", "serde_json", @@ -3689,18 +3701,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.204" +version = "1.0.205" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "e33aedb1a7135da52b7c21791455563facbbcc43d0f0f66165b42c21b3dfb150" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.205" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "692d6f5ac90220161d6774db30c662202721e64aed9058d2c394f451261420c1" dependencies = [ "proc-macro2", "quote", @@ -3720,11 +3732,12 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3781,7 +3794,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.2.6", + "indexmap 2.3.0", "serde", "serde_derive", "serde_json", @@ -3844,9 +3857,9 @@ dependencies = [ [[package]] name = "signal-hook-mio" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", "mio 0.8.11", @@ -4096,20 +4109,21 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4242,9 +4256,9 @@ 
dependencies = [ [[package]] name = "tokio" -version = "1.39.1" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", @@ -4315,21 +4329,21 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.22.17", + "toml_edit 0.22.20", ] [[package]] name = "toml_datetime" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] @@ -4340,22 +4354,22 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.3.0", "toml_datetime", "winnow 0.5.40", ] [[package]] name = "toml_edit" -version = "0.22.17" +version = "0.22.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" +checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.3.0", "serde", "serde_spanned", "toml_datetime", - "winnow 0.6.16", + "winnow 0.6.18", ] [[package]] @@ -4590,9 +4604,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" +checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" dependencies = [ "base64 0.22.1", "flate2", @@ -4630,7 +4644,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.3.0", "serde", "serde_json", "utoipa-gen", @@ -4862,11 +4876,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4912,6 +4926,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -5044,9 +5067,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.6.16" +version = "0.6.18" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" +checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" dependencies = [ "memchr", ] @@ -5102,6 +5125,7 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] @@ -5168,16 +5192,16 @@ dependencies = [ "crossbeam-utils", "displaydoc", "flate2", - "indexmap 2.2.6", + "indexmap 2.3.0", "num_enum", "thiserror", ] [[package]] name = "zip" -version = "2.1.5" +version = "2.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b895748a3ebcb69b9d38dcfdf21760859a4b0d0b0015277640c2ef4c69640e6f" +checksum = "40dd8c92efc296286ce1fbd16657c5dbefff44f1b4ca01cc5f517d8b7b3d3e2e" dependencies = [ "aes", "arbitrary", @@ -5189,7 +5213,7 @@ dependencies = [ "displaydoc", "flate2", "hmac", - "indexmap 2.2.6", + "indexmap 2.3.0", "lzma-rs", "memchr", "pbkdf2", @@ -5227,18 +5251,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.0" +version = "7.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.12+zstd.1.5.6" +version = "2.0.13+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index da8a4ebfa..ee8d04a8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "mistralrs", "mistralrs-bench", "mistralrs-vision", + "mistralrs-quant", ] exclude = [ "mistralrs-paged_attn", @@ -24,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/README.md b/README.md index 727e545c8..4b167807c 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ Mistal.rs supports several model categories: - Please suggest more by raising an issue! - Tool calling: [docs](docs/TOOL_CALLING.md) - Prompt chunking (only without PagedAttention for now): handle larger prompts where the activation size would cause an OOM by sending chunks +- Various quantizations (GGUF, GPTQ, ISQ): [docs](docs/QUANTS.md) This is a demo of interactive mode with streaming running Phi 3 128k mini with quantization via ISQ to Q4K. diff --git a/docs/QUANTS.md b/docs/QUANTS.md new file mode 100644 index 000000000..5c5cb1666 --- /dev/null +++ b/docs/QUANTS.md @@ -0,0 +1,41 @@ +# Quantization in mistral.rs + +Mistral.rs supports the following quantization: +- GGUF/GGML + - Q, K type + - Supported in GGUF/GGML and GGUF/GGML adapter models + - I quants coming! 
+ - CPU, CUDA, Metal (all supported devices) +- GPTQ + - Supported in all plain and adapter models + - CUDA only +- ISQ + - Q, K type GGUF quants + - Supported in all plain and adapter models + - I quants coming! + - GPTQ quants coming! + - CPU, CUDA, Metal (all supported devices) + +## Using a GGUF quantized model +- Use the `gguf` (cli) / `GGUF` (Python) model selector +- Provide the GGUF file + +``` +cargo run --features cuda -- -i gguf -f my-gguf-file.gguf +``` + +## Using ISQ +See the [docs](ISQ.md) + +``` +cargo run --features cuda -- -i --isq Q4K plain -m microsoft/Phi-3-mini-4k-instruct -a phi3 +``` + +## Using a GPTQ quantized model +- Use the `plain` (cli) / `Plain` (Python) model selector +- Provide the model ID for the GPTQ model +- Mistral.rs will automatically detect and use GPTQ quantization. + +``` +cargo run --features cuda -- -i plain -m kaitchup/Phi-3-mini-4k-instruct-gptq-4bit -a phi3 +``` \ No newline at end of file diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 051b608ab..6385d5a8a 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" @@ -71,6 +71,7 @@ base64.workspace = true bytemuck_derive = "1.7.0" plotly = { version = "0.9.0", features = ["kaleido"], optional = true } mistralrs-paged-attn = { version = "0.2.4", path = "../mistralrs-paged-attn", optional = true } +mistralrs-quant = { version = "0.2.0", path = "../mistralrs-quant" } uuid = { version = "1.10.0", features = ["v4"] } schemars = "0.8.21" @@ -78,7 +79,7 @@ schemars = "0.8.21" default = ["plotly"] plotly = ["dep:plotly"] pyo3_macros = ["pyo3"] -cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"] +cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"] cudnn = ["candle-core/cudnn"] metal = ["candle-core/metal", "candle-nn/metal"] flash-attn = ["cuda", "dep:candle-flash-attn"] diff --git a/mistralrs-core/src/cuda/ffi.rs b/mistralrs-core/src/cuda/ffi.rs index 8dc100a18..725f6bb50 100644 --- a/mistralrs-core/src/cuda/ffi.rs +++ b/mistralrs-core/src/cuda/ffi.rs @@ -9,6 +9,7 @@ extern "C" { pub(crate) fn count_nonzero_u8(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32) -> u32; + pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32) -> u32; pub(crate) fn nonzero_bf16( d_in: *const c_void, N: u32, @@ -65,6 +66,14 @@ extern "C" { num_dims: u32, d_out: *mut c_void, ); + pub(crate) fn nonzero_i32( + d_in: *const c_void, + N: u32, + num_nonzero: u32, + dims: *const c_void, + num_dims: u32, + d_out: *mut c_void, + ); pub(crate) fn bitwise_and_u8( d_in1: *const c_void, @@ -84,6 +93,12 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_and_i32( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); pub(crate) fn bitwise_or_u8( d_in1: *const c_void, d_in2: *const c_void, @@ -102,6 +117,12 @@ extern "C" { d_out: *mut 
c_void, N: u32, ); + pub(crate) fn bitwise_or_i32( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); pub(crate) fn bitwise_xor_u8( d_in1: *const c_void, d_in2: *const c_void, @@ -120,4 +141,10 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_xor_i32( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); } diff --git a/mistralrs-core/src/cuda/nonzero_bitwise.cu b/mistralrs-core/src/cuda/nonzero_bitwise.cu index 749c3a346..ef17662a2 100644 --- a/mistralrs-core/src/cuda/nonzero_bitwise.cu +++ b/mistralrs-core/src/cuda/nonzero_bitwise.cu @@ -56,6 +56,7 @@ COUNT_NONZERO_OP(double, f64) COUNT_NONZERO_OP(uint8_t, u8) COUNT_NONZERO_OP(uint32_t, u32) COUNT_NONZERO_OP(int64_t, i64) +COUNT_NONZERO_OP(int32_t, i32) __global__ void transform_indices(const uint32_t *temp_indices, const uint32_t num_nonzero, @@ -126,6 +127,7 @@ NONZERO_OP(double, f64) NONZERO_OP(uint8_t, u8) NONZERO_OP(uint32_t, u32) NONZERO_OP(int64_t, i64) +NONZERO_OP(int32_t, i32) template __global__ void bitwise_and__kernel(const T *d_in1, const T *d_in2, T *d_out, @@ -207,3 +209,4 @@ void bitwise_xor(const T *d_in1, const T *d_in2, T *d_out, int N) { BITWISE_OP(uint8_t, u8) BITWISE_OP(uint32_t, u32) BITWISE_OP(int64_t, i64) +BITWISE_OP(int32_t, i32) diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 02bf9790f..0d172890c 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -16,6 +16,7 @@ use candle_core::{ DType, Device, IndexOp, Result, Shape, Tensor, D, }; use candle_nn::{Linear, Module, VarBuilder}; +use mistralrs_quant::QuantMethod; use serde::Deserialize; pub use crate::layers_masker::CausalMasker; @@ -429,6 +430,15 @@ impl MatMul { matmul.forward(x) } } + + /// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels. 
+ pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result { + if get_use_matmul_via_f16() { + matmul.forward_via_half(x) + } else { + matmul.forward(x) + } + } } /// Computes softmax(QK^T*sqrt(d_k))V diff --git a/mistralrs-core/src/lora/loralinear.rs b/mistralrs-core/src/lora/loralinear.rs index 0dbe9b61c..c4f967dd9 100644 --- a/mistralrs-core/src/lora/loralinear.rs +++ b/mistralrs-core/src/lora/loralinear.rs @@ -1,23 +1,17 @@ -use std::{collections::HashMap, iter::zip, ops::Mul}; +use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc}; -use candle_core::{ - bail, - quantized::{QMatMul, QTensor}, - Module, Result, Tensor, -}; +use candle_core::{bail, DType, Module, Result, Tensor}; use candle_nn::{Linear, VarBuilder}; use either::Either; - -use crate::layers::QLinear; +use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear}; use super::{ apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper, LinearLayerLike, LoraConfig, LoraLinearConfig, Merge, }; -#[derive(Debug)] pub struct LoraLinear { - old: QLinear, + old: Arc, a_adapters: Either, (Tensor, Vec)>, b_adapters: Either, (Tensor, Vec)>, scale_adapters: Vec, @@ -106,7 +100,9 @@ impl LoraLinear { .to_dtype(a_adapters_stack.dtype())?; let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?; Ok(LoraLinear { - old: QLinear::from_parts(old.weight().clone(), old.bias().cloned()), + old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( + Linear::new(old.weight().clone(), old.bias().cloned()), + ))?), a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)), b_adapters: Either::Right((b_adapters_stack, b_adapters)), scale_adapters, @@ -116,7 +112,9 @@ impl LoraLinear { }) } else { Ok(LoraLinear { - old: QLinear::from_parts(old.weight().clone(), old.bias().cloned()), + old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( + Linear::new(old.weight().clone(), old.bias().cloned()), + ))?), a_adapters: Either::Left(a_adapters), b_adapters: Either::Left(b_adapters), scale_adapters, @@ -176,43 +174,36 @@ impl Merge for LoraLinear { } fn merge_weights(&mut self) -> Result<()> { - match &self.old.inner() { - QMatMul::QTensor(q) => { - let (mut w_base_layer, dtype) = (q.dequantize(&q.device())?, q.dtype()); - for adapter in 0..self.scale_adapters.len() { - w_base_layer = (w_base_layer + self.get_delta_weight(adapter))?; - } - let new_w = QTensor::quantize(&w_base_layer, dtype)?; - self.old = QLinear::from_qparts(new_w, self.old.bias().cloned()); - } - QMatMul::Tensor(w_base_layer) | QMatMul::TensorF16(w_base_layer) => { - let mut w_base_layer = w_base_layer.clone(); - for adapter in 0..self.scale_adapters.len() { - w_base_layer = (w_base_layer + self.get_delta_weight(adapter))?; - } - self.old = QLinear::from_parts(w_base_layer, self.old.bias().cloned()); + let mut w_base_layer: Option = None; + for adapter in 0..self.scale_adapters.len() { + if let Some(w_base_layer) = &mut w_base_layer { + *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?; + } else { + w_base_layer = Some(self.get_delta_weight(adapter)?) 
} - }; + } + self.old + .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?; self.merged = true; Ok(()) } } impl LinearLayerLike for LoraLinear { + fn inner(&mut self) -> Option<&mut candle_core::quantized::QMatMul> { + Arc::get_mut(&mut self.old).unwrap().get_qmatmul() + } fn bias(&self) -> Option<&Tensor> { - self.old.bias() + unreachable!() } fn bias_mut(&mut self) -> Option<&mut Tensor> { - self.old.bias_mut() + unreachable!() } fn weight(&self) -> &Tensor { unreachable!() } - fn inner(&mut self) -> &mut QMatMul { - self.old.inner() - } - fn is_quant(&self) -> bool { - self.old.is_quant() + fn quantized_act_type(&self) -> Option { + self.old.quantized_act_type() } fn lora_forward( &self, diff --git a/mistralrs-core/src/lora/mod.rs b/mistralrs-core/src/lora/mod.rs index 6e62fb3e0..dba9a230d 100644 --- a/mistralrs-core/src/lora/mod.rs +++ b/mistralrs-core/src/lora/mod.rs @@ -4,7 +4,7 @@ use std::{collections::HashSet, fmt::Debug, sync::Arc}; use candle_core::{ quantized::{QMatMul, QTensor}, - IndexOp, Result, Tensor, D, + DType, IndexOp, Result, Tensor, D, }; use candle_nn::{init, Linear, Module, VarBuilder}; use loralinear::LoraLinear; @@ -97,9 +97,9 @@ fn make_adapter( } /// Any layer that is linear-like. -pub trait LinearLayerLike: Debug + Merge + AdapterSwapper { - fn inner(&mut self) -> &mut QMatMul; - fn is_quant(&self) -> bool; +pub trait LinearLayerLike: Merge + AdapterSwapper { + fn quantized_act_type(&self) -> Option; + fn inner(&mut self) -> Option<&mut QMatMul>; fn is_lora(&self) -> bool; fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; @@ -152,12 +152,12 @@ impl AdapterSwapper for Linear { } impl LinearLayerLike for Linear { - fn inner(&mut self) -> &mut QMatMul { - unreachable!() - } fn bias(&self) -> Option<&Tensor> { self.bias() } + fn inner(&mut self) -> Option<&mut QMatMul> { + None + } fn bias_mut(&mut self) -> Option<&mut Tensor> { unreachable!() } @@ -173,8 +173,8 @@ impl LinearLayerLike for Linear { ) -> Result { self.forward(x) } - fn is_quant(&self) -> bool { - false + fn quantized_act_type(&self) -> Option { + None } fn is_lora(&self) -> bool { false diff --git a/mistralrs-core/src/lora/qloralinear.rs b/mistralrs-core/src/lora/qloralinear.rs index 54a80b1d4..f48ed8c60 100644 --- a/mistralrs-core/src/lora/qloralinear.rs +++ b/mistralrs-core/src/lora/qloralinear.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, iter::zip, ops::Mul}; use candle_core::{ bail, quantized::{QMatMul, QTensor}, - Module, Result, Tensor, + DType, Module, Result, Tensor, }; use candle_nn::{Linear, VarBuilder}; use either::Either; @@ -229,6 +229,9 @@ impl Merge for QLoraLinear { } impl LinearLayerLike for QLoraLinear { + fn inner(&mut self) -> Option<&mut candle_core::quantized::QMatMul> { + Some(&mut self.old) + } fn bias(&self) -> Option<&Tensor> { None } @@ -238,11 +241,8 @@ impl LinearLayerLike for QLoraLinear { fn weight(&self) -> &Tensor { unimplemented!() } - fn inner(&mut self) -> &mut QMatMul { - &mut self.old - } - fn is_quant(&self) -> bool { - true + fn quantized_act_type(&self) -> Option { + Some(DType::F32) } fn lora_forward( &self, diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index 19ef96196..945eb9d0f 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -129,7 +129,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result anyhow::Result anyhow::Result, } impl Config { @@ -90,12 +91,12 @@ impl Module for RmsNorm { } } -#[derive(Debug, Clone)] 
+#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { - gate_proj: QLinear, - up_proj: QLinear, - down_proj: QLinear, + gate_proj: Arc, + up_proj: Arc, + down_proj: Arc, act_fn: candle_nn::Activation, params: Vec, } @@ -104,13 +105,31 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; - let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?; - let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; + let gate_proj = mistralrs_quant::linear_b( + hidden_sz, + intermediate_sz, + false, + &cfg.quantization_config, + vb.pp("gate_proj"), + )?; + let up_proj = mistralrs_quant::linear_b( + hidden_sz, + intermediate_sz, + false, + &cfg.quantization_config, + vb.pp("up_proj"), + )?; + let down_proj = mistralrs_quant::linear_b( + intermediate_sz, + hidden_sz, + false, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; Ok(Self { - gate_proj: QLinear::from_linear(gate_proj), - up_proj: QLinear::from_linear(up_proj), - down_proj: QLinear::from_linear(down_proj), + gate_proj, + up_proj, + down_proj, act_fn: cfg.hidden_act()?, params: vec![hidden_sz, intermediate_sz], }) @@ -123,29 +142,50 @@ impl MlpLayer for MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.gate_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; - let rhs = xs.apply(&self.up_proj)?; - let mut res = (lhs * rhs)?.apply(&self.down_proj)?; - if self.gate_proj.is_quant() { + let lhs = MatMul + .qmethod_matmul(&xs, &*self.gate_proj)? 
+ .apply(&self.act_fn)?; + let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?; + let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?; + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } vec![ - self.gate_proj.inner(), - self.up_proj.inner(), - self.down_proj.inner(), + Arc::get_mut(&mut self.gate_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.up_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.down_proj).unwrap().get_qmatmul(), ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } vec![ - self.gate_proj.bias_mut(), - self.up_proj.bias_mut(), - self.down_proj.bias_mut(), + Arc::get_mut(&mut self.gate_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.up_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.down_proj).unwrap().get_bias_mut(), ] } fn clone(&self) -> Box { @@ -156,44 +196,41 @@ impl MlpLayer for MLP { } // gate, up, down fn new_added_delta(&self, deltas: Vec>) -> Result> { - let new_gate = if let Some(ref delta) = deltas[0] { - merge_delta!(self.gate_proj.inner_ref(), delta) + let gate_proj = if let Some(ref delta) = deltas[0] { + self.gate_proj.add_delta_w(delta)? } else { - self.gate_proj.inner_ref().clone() + self.gate_proj.clone() }; - let new_up = if let Some(ref delta) = deltas[1] { - merge_delta!(self.up_proj.inner_ref(), delta) + let up_proj = if let Some(ref delta) = deltas[1] { + self.up_proj.add_delta_w(delta)? } else { - self.up_proj.inner_ref().clone() + self.up_proj.clone() }; - let new_down = if let Some(ref delta) = deltas[2] { - merge_delta!(self.down_proj.inner_ref(), delta) + let down_proj = if let Some(ref delta) = deltas[2] { + self.down_proj.add_delta_w(delta)? 
} else { - self.down_proj.inner_ref().clone() + self.down_proj.clone() }; Ok(Box::new(Self { - gate_proj: QLinear::from_old_and_qmatmul(new_gate, &self.gate_proj), - up_proj: QLinear::from_old_and_qmatmul(new_up, &self.up_proj), - down_proj: QLinear::from_old_and_qmatmul(new_down, &self.down_proj), + gate_proj, + up_proj, + down_proj, act_fn: self.act_fn, params: self.params.clone(), })) } fn dtype_device(&self) -> (DType, Device) { - match self.gate_proj.inner_ref() { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.gate_proj.dtype_and_device() } } struct Attention { - q_proj: QLinear, - k_proj: QLinear, - v_proj: QLinear, - o_proj: QLinear, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, @@ -216,15 +253,39 @@ impl Attention { let num_kv_groups = num_heads / num_kv_heads; let head_dim = cfg.head_dim; let bias = cfg.attention_bias; - let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; - let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear_b( + hidden_sz, + num_heads * head_dim, + bias, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_b( + num_heads * head_dim, + hidden_sz, + bias, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QLinear::from_linear(q_proj), - k_proj: QLinear::from_linear(k_proj), - v_proj: QLinear::from_linear(v_proj), - o_proj: QLinear::from_linear(o_proj), + q_proj, + k_proj, + v_proj, + o_proj, num_heads, num_kv_heads, num_kv_groups, @@ -248,13 +309,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = self.q_proj.forward(&xs)?; - let mut k = self.k_proj.forward(&xs)?; - let mut v = self.v_proj.forward(&xs)?; - if self.q_proj.is_quant() { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -325,16 +386,16 @@ impl Attention { } }; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if attention_mask.is_some() { attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? } else { attn_output.reshape((b_sz, q_len, ()))? 
}; - let mut res = attn_output.apply(&self.o_proj)?; - if self.q_proj.is_quant() { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -432,6 +493,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -561,10 +629,40 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.inner(), Some(i))); - tensors.push((layer.self_attn.k_proj.inner(), Some(i))); - tensors.push((layer.self_attn.v_proj.inner(), Some(i))); - tensors.push((layer.self_attn.o_proj.inner(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp @@ -579,10 +677,40 @@ impl IsqModel for Model { fn get_biases(&mut self) -> (Vec<(Option<&mut Tensor>, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.k_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.v_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.o_proj.bias_mut(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + tensors.push(( + Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index 
3ed16af87..f918a4b35 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor, D}; -use candle_nn::{linear_b as linear, Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use crate::{ amoe::{ @@ -12,8 +13,7 @@ use crate::{ }, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{repeat_kv, CausalMasker, MatMul, QLinear}, - merge_delta, + layers::{repeat_kv, CausalMasker, MatMul}, paged_attention::{AttentionImplementation, ModelConfigMetadata}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, @@ -48,6 +48,7 @@ pub struct Config { #[serde(default = "default_max_position_embeddings")] pub max_position_embeddings: usize, + pub quantization_config: Option, } impl Config { @@ -93,12 +94,12 @@ impl Module for RmsNorm { } } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { - gate_proj: QLinear, - up_proj: QLinear, - down_proj: QLinear, + gate_proj: Arc, + up_proj: Arc, + down_proj: Arc, act_fn: candle_nn::Activation, params: Vec, } @@ -107,13 +108,31 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; - let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?; - let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; + let gate_proj = mistralrs_quant::linear_b( + hidden_sz, + intermediate_sz, + false, + &cfg.quantization_config, + vb.pp("gate_proj"), + )?; + let up_proj = mistralrs_quant::linear_b( + hidden_sz, + intermediate_sz, + false, + &cfg.quantization_config, + vb.pp("up_proj"), + )?; + let down_proj = mistralrs_quant::linear_b( + intermediate_sz, + hidden_sz, + false, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; Ok(Self { - gate_proj: QLinear::from_linear(gate_proj), - up_proj: QLinear::from_linear(up_proj), - down_proj: QLinear::from_linear(down_proj), + gate_proj, + up_proj, + down_proj, act_fn: cfg.hidden_act()?, params: vec![hidden_sz, intermediate_sz], }) @@ -126,29 +145,50 @@ impl MlpLayer for MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.gate_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; - let rhs = xs.apply(&self.up_proj)?; - let mut res = (lhs * rhs)?.apply(&self.down_proj)?; - if self.gate_proj.is_quant() { + let lhs = MatMul + .qmethod_matmul(&xs, &*self.gate_proj)? 
+ .apply(&self.act_fn)?; + let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?; + let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?; + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } vec![ - self.gate_proj.inner(), - self.up_proj.inner(), - self.down_proj.inner(), + Arc::get_mut(&mut self.gate_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.up_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.down_proj).unwrap().get_qmatmul(), ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } vec![ - self.gate_proj.bias_mut(), - self.up_proj.bias_mut(), - self.down_proj.bias_mut(), + Arc::get_mut(&mut self.gate_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.up_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.down_proj).unwrap().get_bias_mut(), ] } fn clone(&self) -> Box { @@ -159,45 +199,42 @@ impl MlpLayer for MLP { } // gate, up, down fn new_added_delta(&self, deltas: Vec>) -> Result> { - let new_gate = if let Some(ref delta) = deltas[0] { - merge_delta!(self.gate_proj.inner_ref(), delta) + let gate_proj = if let Some(ref delta) = deltas[0] { + self.gate_proj.add_delta_w(delta)? } else { - self.gate_proj.inner_ref().clone() + self.gate_proj.clone() }; - let new_up = if let Some(ref delta) = deltas[1] { - merge_delta!(self.up_proj.inner_ref(), delta) + let up_proj = if let Some(ref delta) = deltas[1] { + self.up_proj.add_delta_w(delta)? } else { - self.up_proj.inner_ref().clone() + self.up_proj.clone() }; - let new_down = if let Some(ref delta) = deltas[2] { - merge_delta!(self.down_proj.inner_ref(), delta) + let down_proj = if let Some(ref delta) = deltas[2] { + self.down_proj.add_delta_w(delta)? 
} else { - self.down_proj.inner_ref().clone() + self.down_proj.clone() }; Ok(Box::new(Self { - gate_proj: QLinear::from_old_and_qmatmul(new_gate, &self.gate_proj), - up_proj: QLinear::from_old_and_qmatmul(new_up, &self.up_proj), - down_proj: QLinear::from_old_and_qmatmul(new_down, &self.down_proj), + gate_proj, + up_proj, + down_proj, act_fn: self.act_fn, params: self.params.clone(), })) } fn dtype_device(&self) -> (DType, Device) { - match self.gate_proj.inner_ref() { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.gate_proj.dtype_and_device() } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { - q_proj: QLinear, - k_proj: QLinear, - v_proj: QLinear, - o_proj: QLinear, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, @@ -222,15 +259,39 @@ impl Attention { let num_kv_groups = num_heads / num_kv_heads; let head_dim = cfg.head_dim; let bias = cfg.attention_bias; - let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; - let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear_b( + hidden_sz, + num_heads * head_dim, + bias, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_b( + num_heads * head_dim, + hidden_sz, + bias, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QLinear::from_linear(q_proj), - k_proj: QLinear::from_linear(k_proj), - v_proj: QLinear::from_linear(v_proj), - o_proj: QLinear::from_linear(o_proj), + q_proj, + k_proj, + v_proj, + o_proj, num_heads, num_kv_heads, num_kv_groups, @@ -261,13 +322,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = self.q_proj.forward(&xs)?; - let mut k = self.k_proj.forward(&xs)?; - let mut v = self.v_proj.forward(&xs)?; - if self.q_proj.is_quant() { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -344,14 +405,16 @@ impl Attention { // Convert to contiguous as matmul doesn't support strided vs for now. let mut attn_output = MatMul.matmul(&att, &v.contiguous()?)?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } - let mut res = attn_output - .transpose(1, 2)? - .reshape((b_sz, q_len, ()))? - .apply(&self.o_proj)?; - if self.q_proj.is_quant() { + attn_output = if attention_mask.is_some() { + attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? + } else { + attn_output.reshape((b_sz, q_len, ()))? 
+ }; + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -468,6 +531,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -603,10 +673,40 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.inner(), Some(i))); - tensors.push((layer.self_attn.k_proj.inner(), Some(i))); - tensors.push((layer.self_attn.v_proj.inner(), Some(i))); - tensors.push((layer.self_attn.o_proj.inner(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp @@ -621,10 +721,40 @@ impl IsqModel for Model { fn get_biases(&mut self) -> (Vec<(Option<&mut Tensor>, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.k_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.v_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.o_proj.bias_mut(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + tensors.push(( + Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 67fc62506..72bdf6e6d 100644 --- a/mistralrs-core/src/models/llama.rs +++ 
b/mistralrs-core/src/models/llama.rs @@ -1,7 +1,8 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor}; -use candle_nn::{embedding, linear_no_bias as linear, Embedding, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use serde::Deserialize; use std::sync::Arc; @@ -17,7 +18,6 @@ use crate::{ ScaledDotProductAttention, }, layers_masker::PastKvLenCache, - merge_delta, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, IsqModel, @@ -39,13 +39,14 @@ pub struct Config { pub rope_theta: f32, pub max_position_embeddings: usize, pub rope_scaling: Option, + pub quantization_config: Option, } struct CausalSelfAttention { - q_proj: QMatMul, - k_proj: QMatMul, - v_proj: QMatMul, - o_proj: QMatMul, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_attention_heads: usize, num_key_value_heads: usize, head_dim: usize, @@ -67,17 +68,17 @@ impl CausalSelfAttention { kv_cache: &mut crate::pipeline::LayerCaches, metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>, ) -> Result { - let (b_sz, seq_len, hidden_size) = x.dims3()?; + let (b_sz, seq_len, _) = x.dims3()?; let original_dtype = x.dtype(); let mut x = x.clone(); - if matches!(self.q_proj, QMatMul::QTensor(_)) { - x = x.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + x = x.to_dtype(t)?; } - let mut q = MatMul.qmatmul(&x, &self.q_proj)?; - let mut k = MatMul.qmatmul(&x, &self.k_proj)?; - let mut v = MatMul.qmatmul(&x, &self.v_proj)?; - if matches!(self.q_proj, QMatMul::QTensor(_)) { + let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -151,19 +152,19 @@ impl CausalSelfAttention { } }; - if matches!(self.q_proj, QMatMul::QTensor(_)) { - y = y.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + y = y.to_dtype(t)?; } y = if attention_mask.is_some() { - y.transpose(1, 2)?.reshape((b_sz, seq_len, hidden_size))? + y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))? } else { - y.reshape((b_sz, seq_len, hidden_size))? + y.reshape((b_sz, seq_len, ()))? 
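// Illustrative sketch, not from the patch: how the Arc<dyn QuantMethod> projections above
// are constructed in the load() function below. mistralrs_quant::linear_no_bias reads
// cfg.quantization_config and returns a GPTQ-backed layer when one is requested, otherwise
// an unquantized layer that can still be converted for ISQ later. `make_q_proj` is a
// hypothetical helper; the argument order follows the calls in this diff.
use std::sync::Arc;

use candle_core::Result;
use candle_nn::VarBuilder;
use mistralrs_quant::{QuantMethod, QuantizedConfig};

fn make_q_proj(
    hidden_size: usize,
    num_heads: usize,
    head_dim: usize,
    quantization_config: &Option<QuantizedConfig>,
    vb: VarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    mistralrs_quant::linear_no_bias(
        hidden_size,
        num_heads * head_dim,
        quantization_config,
        vb.pp("q_proj"),
    )
}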
}; - let mut y = MatMul.qmatmul(&y, &self.o_proj)?; - if matches!(self.q_proj, QMatMul::QTensor(_)) { - y = y.to_dtype(original_dtype)?; + let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { + res = res.to_dtype(original_dtype)?; } - Ok(y) + Ok(res) } fn load( @@ -175,15 +176,35 @@ impl CausalSelfAttention { let size_in = cfg.hidden_size; let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; - let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; - let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; - let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; - let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear_no_bias( + size_in, + size_q, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_no_bias( + size_in, + size_kv, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_no_bias( + size_in, + size_kv, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_no_bias( + size_q, + size_in, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QMatMul::Tensor(q_proj.weight().clone()), - k_proj: QMatMul::Tensor(k_proj.weight().clone()), - v_proj: QMatMul::Tensor(v_proj.weight().clone()), - o_proj: QMatMul::Tensor(o_proj.weight().clone()), + q_proj, + k_proj, + v_proj, + o_proj, num_attention_heads: cfg.num_attention_heads, num_key_value_heads: cfg.num_key_value_heads, head_dim: cfg.hidden_size / cfg.num_attention_heads, @@ -195,11 +216,11 @@ impl CausalSelfAttention { } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { - c_fc1: QMatMul, - c_fc2: QMatMul, - c_proj: QMatMul, + c_fc1: Arc, + c_fc2: Arc, + c_proj: Arc, params: Vec, } @@ -207,13 +228,28 @@ impl Mlp { fn load(vb: VarBuilder, cfg: &Config) -> Result { let h_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; - let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; - let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + let c_fc1 = mistralrs_quant::linear_no_bias( + h_size, + i_size, + &cfg.quantization_config, + vb.pp("gate_proj"), + )?; + let c_fc2 = mistralrs_quant::linear_no_bias( + h_size, + i_size, + &cfg.quantization_config, + vb.pp("up_proj"), + )?; + let c_proj = mistralrs_quant::linear_no_bias( + i_size, + h_size, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; Ok(Self { - c_fc1: QMatMul::Tensor(c_fc1.weight().clone()), - c_fc2: QMatMul::Tensor(c_fc2.weight().clone()), - c_proj: QMatMul::Tensor(c_proj.weight().clone()), + c_fc1, + c_fc2, + c_proj, params: vec![h_size, i_size], }) } @@ -225,19 +261,34 @@ impl MlpLayer for Mlp { fn forward(&self, x: &Tensor) -> Result { let original_dtype = x.dtype(); let mut x = x.clone(); - if matches!(self.c_fc1, QMatMul::QTensor(_)) { - x = x.to_dtype(DType::F32)?; + if let Some(t) = self.c_fc1.quantized_act_type() { + x = x.to_dtype(t)?; } - let x = (candle_nn::ops::silu(&MatMul.qmatmul(&x, &self.c_fc1)?)? - * MatMul.qmatmul(&x, &self.c_fc2)?)?; - let mut res = MatMul.qmatmul(&x, &self.c_proj)?; - if matches!(self.c_fc1, QMatMul::QTensor(_)) { + let x = (candle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)? 
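// Illustrative sketch, not from the patch: the SwiGLU MLP body being rewritten at this
// point, spelled out as one hypothetical function. silu(gate(x)) * up(x) is down-projected,
// with every projection going through qmethod_matmul so GPTQ, GGUF and unquantized weights
// share one code path. Field names mirror the llama Mlp; the dtype casts are omitted here.
use candle_core::{Result, Tensor};
use mistralrs_quant::QuantMethod;

use crate::layers::MatMul;

fn swiglu_mlp(
    xs: &Tensor,
    c_fc1: &dyn QuantMethod,  // gate_proj
    c_fc2: &dyn QuantMethod,  // up_proj
    c_proj: &dyn QuantMethod, // down_proj
) -> Result<Tensor> {
    let gate = candle_nn::ops::silu(&MatMul.qmethod_matmul(xs, c_fc1)?)?;
    let up = MatMul.qmethod_matmul(xs, c_fc2)?;
    MatMul.qmethod_matmul(&(gate * up)?, c_proj)
}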
+ * MatMul.qmethod_matmul(&x, &*self.c_fc2)?)?; + let mut res = MatMul.qmethod_matmul(&x, &*self.c_proj)?; + if self.c_fc1.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { - vec![&mut self.c_fc1, &mut self.c_fc2, &mut self.c_proj] + { + let c_fc1 = self.c_fc1.clone().convert_to_isq().unwrap(); + self.c_fc1 = c_fc1; + let c_fc2 = self.c_fc2.clone().convert_to_isq().unwrap(); + self.c_fc2 = c_fc2; + let c_proj = self.c_proj.clone().convert_to_isq().unwrap(); + self.c_proj = c_proj; + } + vec![ + Arc::get_mut(&mut self.c_fc1).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.c_fc2).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.c_proj).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { vec![None, None, None] @@ -251,17 +302,17 @@ impl MlpLayer for Mlp { // c_fc1, c_fc2, c_proj fn new_added_delta(&self, deltas: Vec>) -> Result> { let new_c_fc1 = if let Some(ref delta) = deltas[0] { - merge_delta!(self.c_fc1, delta) + self.c_fc1.add_delta_w(delta)? } else { self.c_fc1.clone() }; let new_c_fc2 = if let Some(ref delta) = deltas[1] { - merge_delta!(self.c_fc2, delta) + self.c_fc2.add_delta_w(delta)? } else { self.c_fc2.clone() }; let new_c_proj = if let Some(ref delta) = deltas[2] { - merge_delta!(self.c_proj, delta) + self.c_proj.add_delta_w(delta)? } else { self.c_proj.clone() }; @@ -275,10 +326,7 @@ impl MlpLayer for Mlp { } fn dtype_device(&self) -> (DType, Device) { - match &self.c_fc1 { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.c_fc1.dtype_and_device() } } @@ -413,6 +461,13 @@ impl Llama { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -423,7 +478,7 @@ impl Llama { cfg.hidden_size, mapper.set_nm_device(vb.pp("model.embed_tokens"), false), )?; - let lm_head = linear( + let lm_head = candle_nn::linear( cfg.hidden_size, cfg.vocab_size, mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), @@ -497,10 +552,28 @@ impl IsqModel for Llama { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.blocks.iter_mut().enumerate() { - tensors.push((&mut layer.attn.q_proj, Some(i))); - tensors.push((&mut layer.attn.k_proj, Some(i))); - tensors.push((&mut layer.attn.v_proj, Some(i))); - tensors.push((&mut layer.attn.o_proj, Some(i))); + { + let q_proj = layer.attn.q_proj.clone().convert_to_isq().unwrap(); + layer.attn.q_proj = q_proj; + let k_proj = layer.attn.k_proj.clone().convert_to_isq().unwrap(); + layer.attn.k_proj = k_proj; + let v_proj = layer.attn.v_proj.clone().convert_to_isq().unwrap(); + layer.attn.v_proj = v_proj; + let o_proj = layer.attn.o_proj.clone().convert_to_isq().unwrap(); + layer.attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.attn.q_proj).unwrap().get_qmatmul() { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.attn.k_proj).unwrap().get_qmatmul() { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.attn.v_proj).unwrap().get_qmatmul() { + 
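// Illustrative sketch, not from the patch: the ISQ collection step being repeated for each
// projection here. The layer is first swapped for its convert_to_isq() form, then the inner
// QMatMul is borrowed mutably so the in-situ quantizer can requantize it in place.
// Arc::get_mut is expected to succeed because the converted Arc was just created and is
// uniquely owned. `collect_isq` is a hypothetical helper.
use std::sync::Arc;

use candle_core::quantized::QMatMul;
use mistralrs_quant::QuantMethod;

fn collect_isq(proj: &mut Arc<dyn QuantMethod>) -> Option<&mut QMatMul> {
    let converted = proj.clone().convert_to_isq().unwrap();
    *proj = converted;
    Arc::get_mut(proj).unwrap().get_qmatmul()
}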
tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.attn.o_proj).unwrap().get_qmatmul() { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index deb9c04b4..88f9db74b 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -2,7 +2,8 @@ /// Mistral LLM, https://github.com/mistralai/mistral-src use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; -use candle_nn::{linear_no_bias, Activation, VarBuilder}; +use candle_nn::{Activation, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use std::sync::Arc; use crate::{ @@ -16,7 +17,6 @@ use crate::{ repeat_kv, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, ScaledDotProductAttention, }, layers_masker::PastKvLenCache, - merge_delta, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, @@ -25,7 +25,7 @@ use crate::{ utils::progress::NiceProgressBar, }; -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, Default)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -40,6 +40,7 @@ pub struct Config { pub(crate) sliding_window: Option, pub(crate) use_flash_attn: bool, pub(crate) head_dim: Option, + pub(crate) quantization_config: Option, } impl Config { @@ -49,12 +50,12 @@ impl Config { } } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { - gate_proj: QMatMul, - up_proj: QMatMul, - down_proj: QMatMul, + gate_proj: Arc, + up_proj: Arc, + down_proj: Arc, act_fn: Activation, params: Vec, } @@ -63,13 +64,28 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + intermediate_sz, + &cfg.quantization_config, + vb.pp("gate_proj"), + )?; + let up_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + intermediate_sz, + &cfg.quantization_config, + vb.pp("up_proj"), + )?; + let down_proj = mistralrs_quant::linear_no_bias( + intermediate_sz, + hidden_sz, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; Ok(Self { - gate_proj: QMatMul::Tensor(gate_proj.weight().clone()), - up_proj: QMatMul::Tensor(up_proj.weight().clone()), - down_proj: QMatMul::Tensor(down_proj.weight().clone()), + gate_proj, + up_proj, + down_proj, act_fn: cfg.hidden_act, params: vec![hidden_sz, intermediate_sz], }) @@ -82,22 +98,51 @@ impl MlpLayer for MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.gate_proj, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let lhs = MatMul.qmatmul(&xs, &self.gate_proj)?.apply(&self.act_fn)?; - let rhs = MatMul.qmatmul(&xs, &self.up_proj)?; - let mut res = MatMul.qmatmul(&(lhs * rhs)?, &self.down_proj)?; - if matches!(self.gate_proj, QMatMul::QTensor(_)) { + let lhs = MatMul + .qmethod_matmul(&xs, &*self.gate_proj)? 
+ .apply(&self.act_fn)?; + let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?; + let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?; + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { - vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj] + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } + vec![ + Arc::get_mut(&mut self.gate_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.up_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.down_proj).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { - vec![None, None, None] + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } + vec![ + Arc::get_mut(&mut self.gate_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.up_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.down_proj).unwrap().get_bias_mut(), + ] } fn clone(&self) -> Box { Box::new(Clone::clone(self)) @@ -107,44 +152,41 @@ impl MlpLayer for MLP { } // gate, up, down fn new_added_delta(&self, deltas: Vec>) -> Result> { - let new_gate = if let Some(ref delta) = deltas[0] { - merge_delta!(self.gate_proj, delta) + let gate_proj = if let Some(ref delta) = deltas[0] { + self.gate_proj.add_delta_w(delta)? } else { self.gate_proj.clone() }; - let new_up = if let Some(ref delta) = deltas[1] { - merge_delta!(self.up_proj, delta) + let up_proj = if let Some(ref delta) = deltas[1] { + self.up_proj.add_delta_w(delta)? } else { self.up_proj.clone() }; - let new_down = if let Some(ref delta) = deltas[2] { - merge_delta!(self.down_proj, delta) + let down_proj = if let Some(ref delta) = deltas[2] { + self.down_proj.add_delta_w(delta)? 
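// Illustrative sketch, not from the patch: what new_added_delta now does per weight. The
// old merge_delta! macro on a raw QMatMul is replaced by QuantMethod::add_delta_w, which
// returns a new layer with the supplied weight delta folded in, or a plain Arc clone when
// no delta is given. `apply_delta` is a hypothetical helper.
use std::sync::Arc;

use candle_core::{Result, Tensor};
use mistralrs_quant::QuantMethod;

fn apply_delta(
    layer: &Arc<dyn QuantMethod>,
    delta: Option<&Tensor>,
) -> Result<Arc<dyn QuantMethod>> {
    match delta {
        Some(d) => layer.add_delta_w(d),
        None => Ok(layer.clone()),
    }
}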
} else { self.down_proj.clone() }; Ok(Box::new(Self { - gate_proj: new_gate, - up_proj: new_up, - down_proj: new_down, + gate_proj, + up_proj, + down_proj, act_fn: self.act_fn, params: self.params.clone(), })) } fn dtype_device(&self) -> (DType, Device) { - match &self.gate_proj { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.gate_proj.dtype_and_device() } } struct Attention { - q_proj: QMatMul, - k_proj: QMatMul, - v_proj: QMatMul, - o_proj: QMatMul, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, @@ -167,15 +209,35 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; let head_dim = cfg.head_dim(); - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + num_heads * head_dim, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_no_bias( + num_heads * head_dim, + hidden_sz, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QMatMul::Tensor(q_proj.weight().clone()), - k_proj: QMatMul::Tensor(k_proj.weight().clone()), - v_proj: QMatMul::Tensor(v_proj.weight().clone()), - o_proj: QMatMul::Tensor(o_proj.weight().clone()), + q_proj, + k_proj, + v_proj, + o_proj, num_heads, num_kv_heads, num_kv_groups, @@ -200,13 +262,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.q_proj, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = MatMul.qmatmul(&xs, &self.q_proj)?; - let mut k = MatMul.qmatmul(&xs, &self.k_proj)?; - let mut v = MatMul.qmatmul(&xs, &self.v_proj)?; - if matches!(self.q_proj, QMatMul::QTensor(_)) { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -284,17 +346,16 @@ impl Attention { } }; - if matches!(self.q_proj, QMatMul::QTensor(_)) { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if attention_mask.is_some() { attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? } else { attn_output.reshape((b_sz, q_len, ()))? 
}; - - let mut res = MatMul.qmatmul(&attn_output, &self.o_proj)?; - if matches!(self.q_proj, QMatMul::QTensor(_)) { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -412,6 +473,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -468,7 +536,7 @@ impl Model { cfg.rms_norm_eps, mapper.set_nm_device(vb_m.pp("norm"), false), )?; - let lm_head = linear_no_bias( + let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq), @@ -566,10 +634,40 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((&mut layer.self_attn.q_proj, Some(i))); - tensors.push((&mut layer.self_attn.k_proj, Some(i))); - tensors.push((&mut layer.self_attn.v_proj, Some(i))); - tensors.push((&mut layer.self_attn.o_proj, Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index 1576bd916..f03606f5c 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -4,7 +4,8 @@ /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py /// https://mistral.ai/news/mixtral-of-experts/ use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; -use candle_nn::{linear_no_bias, Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use serde::Deserialize; use std::sync::Arc; @@ -22,7 +23,7 @@ use crate::{ }; /// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -38,18 +39,18 @@ pub struct Config { pub(crate) num_experts_per_tok: usize, pub(crate) num_local_experts: usize, 
pub(crate) use_flash_attn: bool, + pub(crate) quantization_config: Option, } struct Attention { - q_proj: QMatMul, - k_proj: QMatMul, - v_proj: QMatMul, - o_proj: QMatMul, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, head_dim: usize, - hidden_size: usize, rotary_emb: Arc, use_flash_attn: bool, sliding_window: Option, @@ -68,20 +69,39 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + num_heads * head_dim, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_no_bias( + num_heads * head_dim, + hidden_sz, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QMatMul::Tensor(q_proj.weight().clone()), - k_proj: QMatMul::Tensor(k_proj.weight().clone()), - v_proj: QMatMul::Tensor(v_proj.weight().clone()), - o_proj: QMatMul::Tensor(o_proj.weight().clone()), + q_proj, + k_proj, + v_proj, + o_proj, num_heads, num_kv_heads, num_kv_groups, head_dim, - hidden_size: hidden_sz, rotary_emb, use_flash_attn: cfg.use_flash_attn, sliding_window: cfg.sliding_window, @@ -102,13 +122,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.q_proj, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = MatMul.qmatmul(&xs, &self.q_proj)?; - let mut k = MatMul.qmatmul(&xs, &self.k_proj)?; - let mut v = MatMul.qmatmul(&xs, &self.v_proj)?; - if matches!(self.q_proj, QMatMul::QTensor(_)) { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -186,29 +206,27 @@ impl Attention { } }; - if matches!(self.q_proj, QMatMul::QTensor(_)) { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if attention_mask.is_some() { - attn_output - .transpose(1, 2)? - .reshape(&[b_sz, q_len, self.hidden_size])? + attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? } else { - attn_output.reshape(&[b_sz, q_len, self.hidden_size])? + attn_output.reshape((b_sz, q_len, ()))? 
}; - let mut res = MatMul.qmatmul(&attn_output, &self.o_proj)?; - if matches!(self.q_proj, QMatMul::QTensor(_)) { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct BlockSparseTop2MLP { - w1: QMatMul, - w2: QMatMul, - w3: QMatMul, + w1: Arc, + w2: Arc, + w3: Arc, act_fn: Activation, } @@ -216,13 +234,28 @@ impl BlockSparseTop2MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?; - let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?; - let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?; + let w1 = mistralrs_quant::linear_no_bias( + hidden_sz, + intermediate_sz, + &cfg.quantization_config, + vb.pp("w1"), + )?; + let w2 = mistralrs_quant::linear_no_bias( + intermediate_sz, + hidden_sz, + &cfg.quantization_config, + vb.pp("w2"), + )?; + let w3 = mistralrs_quant::linear_no_bias( + hidden_sz, + intermediate_sz, + &cfg.quantization_config, + vb.pp("w3"), + )?; Ok(Self { - w1: QMatMul::Tensor(w1.weight().clone()), - w2: QMatMul::Tensor(w2.weight().clone()), - w3: QMatMul::Tensor(w3.weight().clone()), + w1, + w2, + w3, act_fn: cfg.hidden_act, }) } @@ -232,29 +265,34 @@ impl Module for BlockSparseTop2MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.w1, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.w1.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let lhs = MatMul.qmatmul(&xs, &self.w1)?.apply(&self.act_fn)?; - let rhs = MatMul.qmatmul(&xs, &self.w3)?; - let mut res = MatMul.qmatmul(&(lhs * rhs)?, &self.w2)?; - if matches!(self.w1, QMatMul::QTensor(_)) { + let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?; + let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?; + let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?; + if self.w1.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct SparseMoeBlock { - gate: QMatMul, + gate: Arc, experts: Vec, num_experts_per_tok: usize, } impl SparseMoeBlock { fn new(cfg: &Config, vb: VarBuilder) -> Result { - let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?; + let gate = mistralrs_quant::linear_no_bias( + cfg.hidden_size, + cfg.num_local_experts, + &cfg.quantization_config, + vb.pp("gate"), + )?; let mut experts = Vec::with_capacity(cfg.num_local_experts); let vb = vb.pp("experts"); for idx in 0..cfg.num_local_experts { @@ -262,7 +300,7 @@ impl SparseMoeBlock { experts.push(expert) } Ok(SparseMoeBlock { - gate: QMatMul::Tensor(gate.weight().clone()), + gate, experts, num_experts_per_tok: cfg.num_experts_per_tok, }) @@ -276,11 +314,11 @@ impl Module for SparseMoeBlock { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.gate, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut router_logits = MatMul.qmatmul(&xs, &self.gate)?; - if matches!(self.gate, QMatMul::QTensor(_)) { + let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?; + if self.gate.quantized_act_type().is_some() { router_logits = router_logits.to_dtype(original_dtype)?; } @@ -432,6 +470,13 
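// Illustrative sketch, not from the patch: for the MoE case, ISQ collection (see the
// IsqModel impl below) has to walk the quantized router gate plus every expert's w1/w2/w3.
// `collect_moe_isq` is a hypothetical helper and the Vec<Vec<...>> expert layout is only a
// stand-in for the BlockSparseTop2MLP fields; the convert/get_qmatmul steps follow the diff.
use std::sync::Arc;

use candle_core::quantized::QMatMul;
use mistralrs_quant::QuantMethod;

fn collect_moe_isq<'a>(
    gate: &'a mut Arc<dyn QuantMethod>,
    experts: &'a mut Vec<Vec<Arc<dyn QuantMethod>>>, // per expert: [w1, w2, w3]
) -> Vec<&'a mut QMatMul> {
    let mut out = Vec::new();
    let gate_isq = gate.clone().convert_to_isq().unwrap();
    *gate = gate_isq;
    if let Some(g) = Arc::get_mut(gate).unwrap().get_qmatmul() {
        out.push(g);
    }
    for w in experts.iter_mut().flatten() {
        let w_isq = w.clone().convert_to_isq().unwrap();
        *w = w_isq;
        if let Some(q) = Arc::get_mut(w).unwrap().get_qmatmul() {
            out.push(q);
        }
    }
    out
}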
@@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -488,7 +533,7 @@ impl Model { cfg.rms_norm_eps, mapper.set_nm_device(vb_m.pp("norm"), false), )?; - let lm_head = linear_no_bias( + let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), @@ -563,15 +608,63 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((&mut layer.self_attn.q_proj, Some(i))); - tensors.push((&mut layer.self_attn.k_proj, Some(i))); - tensors.push((&mut layer.self_attn.v_proj, Some(i))); - tensors.push((&mut layer.self_attn.o_proj, Some(i))); - tensors.push((&mut layer.block_sparse_moe.gate, Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + let gate = layer + .block_sparse_moe + .gate + .clone() + .convert_to_isq() + .unwrap(); + layer.block_sparse_moe.gate = gate; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } + if let Some(g) = Arc::get_mut(&mut layer.block_sparse_moe.gate) + .unwrap() + .get_qmatmul() + { + tensors.push((g, Some(i))); + } for expert in &mut layer.block_sparse_moe.experts { - tensors.push((&mut expert.w1, Some(i))); - tensors.push((&mut expert.w2, Some(i))); - tensors.push((&mut expert.w3, Some(i))); + if let Some(w1) = Arc::get_mut(&mut expert.w1).unwrap().get_qmatmul() { + tensors.push((w1, Some(i))); + } + if let Some(w2) = Arc::get_mut(&mut expert.w2).unwrap().get_qmatmul() { + tensors.push((w2, Some(i))); + } + if let Some(w3) = Arc::get_mut(&mut expert.w3).unwrap().get_qmatmul() { + tensors.push((w3, Some(i))); + } } } (tensors, &*self.mapper) diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index e1849dcb7..2ab08dd3d 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -1,14 +1,16 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use std::sync::Arc; + /// Phi model. /// https://huggingface.co/microsoft/phi-2 -/// There is an alternative implementation of the phi model in mixformers.rs. 
/// This corresponds to the model update made with the following commit: /// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 -use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; +use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor}; use candle_nn::{ - embedding, layer_norm, linear, Activation, Embedding, LayerNorm, RotaryEmbedding, VarBuilder, + embedding, layer_norm, Activation, Embedding, LayerNorm, RotaryEmbedding, VarBuilder, }; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use serde::Deserialize; use crate::{ @@ -18,9 +20,8 @@ use crate::{ }, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{repeat_kv, CausalMasker, QLinear, ScaledDotProductAttention}, + layers::{repeat_kv, CausalMasker, MatMul, QLinear, ScaledDotProductAttention}, layers_masker::PastKvLenCache, - merge_delta, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, @@ -30,7 +31,7 @@ use crate::{ }; // https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py -#[derive(Debug, Clone, PartialEq, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Default)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -41,11 +42,11 @@ pub struct Config { pub(crate) hidden_act: Activation, pub(crate) max_position_embeddings: usize, pub(crate) layer_norm_eps: f64, - pub(crate) tie_word_embeddings: bool, pub(crate) rope_theta: f32, pub(crate) partial_rotary_factor: f64, pub(crate) qk_layernorm: bool, pub(crate) use_flash_attn: bool, + pub(crate) quantization_config: Option, } impl Config { @@ -58,22 +59,32 @@ impl Config { } } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { - fc1: QLinear, - fc2: QLinear, + fc1: Arc, + fc2: Arc, act: Activation, params: Vec, } impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { - let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?; - let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; + let fc1 = mistralrs_quant::linear( + cfg.hidden_size, + cfg.intermediate_size, + &cfg.quantization_config, + vb.pp("fc1"), + )?; + let fc2 = mistralrs_quant::linear( + cfg.intermediate_size, + cfg.hidden_size, + &cfg.quantization_config, + vb.pp("fc2"), + )?; Ok(Self { - fc1: QLinear::from_linear(fc1), - fc2: QLinear::from_linear(fc2), + fc1, + fc2, // This does not match the mixformers implementation where Gelu is used rather than // GeluNew. 
act: cfg.hidden_act, @@ -88,20 +99,44 @@ impl MlpLayer for MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.fc1.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.fc1.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut res = xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)?; - if self.fc1.is_quant() { + let mut res = MatMul.qmethod_matmul( + &MatMul.qmethod_matmul(&xs, &*self.fc1)?.apply(&self.act)?, + &*self.fc2, + )?; + if self.fc1.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { - vec![self.fc1.inner(), self.fc2.inner()] + { + let fc1 = self.fc1.clone().convert_to_isq().unwrap(); + self.fc1 = fc1; + let fc2 = self.fc2.clone().convert_to_isq().unwrap(); + self.fc2 = fc2; + } + vec![ + Arc::get_mut(&mut self.fc1).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.fc2).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { - vec![self.fc1.bias_mut(), self.fc2.bias_mut()] + { + let fc1 = self.fc1.clone().convert_to_isq().unwrap(); + self.fc1 = fc1; + let fc2 = self.fc2.clone().convert_to_isq().unwrap(); + self.fc2 = fc2; + } + vec![ + Arc::get_mut(&mut self.fc1).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.fc2).unwrap().get_bias_mut(), + ] } fn clone(&self) -> Box { Box::new(Clone::clone(self)) @@ -112,37 +147,34 @@ impl MlpLayer for MLP { // fc1, fc2 fn new_added_delta(&self, deltas: Vec>) -> Result> { let new_fc1 = if let Some(ref delta) = deltas[0] { - merge_delta!(self.fc1.inner_ref(), delta) + self.fc1.add_delta_w(delta)? } else { - self.fc1.inner_ref().clone() + self.fc1.clone() }; let new_fc2 = if let Some(ref delta) = deltas[1] { - merge_delta!(self.fc2.inner_ref(), delta) + self.fc2.add_delta_w(delta)? 
} else { - self.fc2.inner_ref().clone() + self.fc2.clone() }; Ok(Box::new(Self { - fc1: QLinear::from_old_and_qmatmul(new_fc1, &self.fc1), - fc2: QLinear::from_old_and_qmatmul(new_fc2, &self.fc2), + fc1: new_fc1, + fc2: new_fc2, act: self.act, params: self.params.clone(), })) } fn dtype_device(&self) -> (DType, Device) { - match self.fc1.inner_ref() { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.fc1.dtype_and_device() } } struct Attention { - q_proj: QLinear, - k_proj: QLinear, - v_proj: QLinear, - dense: QLinear, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + dense: Arc, q_layernorm: Option, k_layernorm: Option, rotary_emb: RotaryEmbedding, @@ -163,10 +195,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads(); let head_dim = cfg.head_dim(); - let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?; + let q_proj = mistralrs_quant::linear( + cfg.hidden_size, + num_heads * head_dim, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear( + cfg.hidden_size, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear( + cfg.hidden_size, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let dense = mistralrs_quant::linear( + num_heads * head_dim, + cfg.hidden_size, + &cfg.quantization_config, + vb.pp("dense"), + )?; let (q_layernorm, k_layernorm) = if cfg.qk_layernorm { let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?; let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?; @@ -175,10 +227,10 @@ impl Attention { (None, None) }; Ok(Self { - q_proj: QLinear::from_linear(q_proj), - k_proj: QLinear::from_linear(k_proj), - v_proj: QLinear::from_linear(v_proj), - dense: QLinear::from_linear(dense), + q_proj, + k_proj, + v_proj, + dense, q_layernorm, k_layernorm, rotary_emb: rope, @@ -203,13 +255,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = self.q_proj.forward(&xs)?; - let mut k = self.k_proj.forward(&xs)?; - let mut v = self.v_proj.forward(&xs)?; - if self.q_proj.is_quant() { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -294,8 +346,8 @@ impl Attention { } }; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if mask.is_some() { attn_output @@ -304,8 +356,8 @@ impl Attention { } else { attn_output.reshape((b_size, seq_len, ()))? 
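// Illustrative sketch, not from the patch: Phi-2 keeps its projection biases, so its layers
// are built with mistralrs_quant::linear (models with an optional bias use linear_b and a
// flag instead), and the ISQ bias pass in get_biases below borrows them back out through
// get_bias_mut(), mirroring get_qmatmul() for the weights. `collect_bias` is a hypothetical
// helper.
use std::sync::Arc;

use candle_core::Tensor;
use mistralrs_quant::QuantMethod;

fn collect_bias(proj: &mut Arc<dyn QuantMethod>) -> Option<&mut Tensor> {
    let converted = proj.clone().convert_to_isq().unwrap();
    *proj = converted;
    Arc::get_mut(proj).unwrap().get_bias_mut()
}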
}; - let mut res = attn_output.apply(&self.dense)?; - if self.q_proj.is_quant() { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.dense)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -391,6 +443,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -448,7 +507,7 @@ impl Model { )?; layers.push(layer) } - let lm_head = linear( + let lm_head = candle_nn::linear( cfg.hidden_size, cfg.vocab_size, mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), @@ -520,10 +579,40 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((self.lm_head.inner(), None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.inner(), Some(i))); - tensors.push((layer.self_attn.k_proj.inner(), Some(i))); - tensors.push((layer.self_attn.v_proj.inner(), Some(i))); - tensors.push((layer.self_attn.dense.inner(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let dense = layer.self_attn.dense.clone().convert_to_isq().unwrap(); + layer.self_attn.dense = dense; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.dense) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp @@ -538,10 +627,40 @@ impl IsqModel for Model { fn get_biases(&mut self) -> (Vec<(Option<&mut Tensor>, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.k_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.v_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.dense.bias_mut(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let dense = layer.self_attn.dense.clone().convert_to_isq().unwrap(); + layer.self_attn.dense = dense; + } + tensors.push(( + Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + 
Arc::get_mut(&mut layer.self_attn.dense) + .unwrap() + .get_bias_mut(), + Some(i), + )); tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 8074cc24c..bc611cad2 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -3,7 +3,8 @@ // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor, D}; -use candle_nn::{linear_no_bias, VarBuilder}; +use candle_nn::VarBuilder; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use std::{collections::HashMap, sync::Arc}; use crate::{ @@ -18,7 +19,6 @@ use crate::{ ScaledDotProductAttention, }, layers_masker::PastKvLenCache, - merge_delta, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, @@ -46,6 +46,7 @@ pub struct Config { pub use_flash_attn: bool, pub sliding_window: Option, pub original_max_position_embeddings: usize, + pub quantization_config: Option, } impl From for PhiRopeConfig { @@ -67,8 +68,8 @@ impl Config { } struct Attention { - qkv_proj: QMatMul, - o_proj: QMatMul, + qkv_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, @@ -90,11 +91,24 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.head_dim(); let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim; - let qkv_proj = linear_no_bias(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?; + + let qkv_proj = mistralrs_quant::linear_no_bias( + cfg.hidden_size, + op_size, + &cfg.quantization_config, + vb.pp("qkv_proj"), + )?; + + let o_proj = mistralrs_quant::linear_no_bias( + num_heads * head_dim, + cfg.hidden_size, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; + Ok(Self { - qkv_proj: QMatMul::Tensor(qkv_proj.weight().clone()), - o_proj: QMatMul::Tensor(o_proj.weight().clone()), + qkv_proj, + o_proj, rotary_emb, num_heads, num_kv_heads, @@ -119,11 +133,11 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.qkv_proj, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.qkv_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut qkv = MatMul.qmatmul(&xs, &self.qkv_proj)?; - if matches!(self.qkv_proj, QMatMul::QTensor(_)) { + let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?; + if self.qkv_proj.quantized_act_type().is_some() { qkv = qkv.to_dtype(original_dtype)?; } let query_pos = self.num_heads * self.head_dim; @@ -196,26 +210,26 @@ impl Attention { )? } }; - if matches!(self.qkv_proj, QMatMul::QTensor(_)) { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.qkv_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if attention_mask.is_some() { attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? } else { attn_output.reshape((b_sz, q_len, ()))? 
}; - let mut res = MatMul.qmatmul(&attn_output, &self.o_proj)?; - if matches!(self.qkv_proj, QMatMul::QTensor(_)) { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.qkv_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { - gate_up_proj: QMatMul, - down_proj: QMatMul, + gate_up_proj: Arc, + down_proj: Arc, act_fn: candle_nn::Activation, i_size: usize, params: Vec, @@ -225,11 +239,24 @@ impl Mlp { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let gate_up_proj = linear_no_bias(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?; - let down_proj = linear_no_bias(i_size, hidden_size, vb.pp("down_proj"))?; + + let gate_up_proj = mistralrs_quant::linear_no_bias( + hidden_size, + 2 * i_size, + &cfg.quantization_config, + vb.pp("gate_up_proj"), + )?; + + let down_proj = mistralrs_quant::linear_no_bias( + i_size, + hidden_size, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; + Ok(Self { - gate_up_proj: QMatMul::Tensor(gate_up_proj.weight().clone()), - down_proj: QMatMul::Tensor(down_proj.weight().clone()), + gate_up_proj, + down_proj, act_fn: cfg.hidden_act, i_size, params: vec![hidden_size, i_size], @@ -243,21 +270,33 @@ impl MlpLayer for Mlp { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.gate_up_proj, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_up_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let up_states = MatMul.qmatmul(&xs, &self.gate_up_proj)?; + let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?; let gate = up_states.narrow(D::Minus1, 0, self.i_size)?; let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?; let up_states = (up_states * gate.apply(&self.act_fn))?; - let mut res = MatMul.qmatmul(&up_states, &self.down_proj)?; - if matches!(self.gate_up_proj, QMatMul::QTensor(_)) { + let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?; + if self.gate_up_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { - vec![&mut self.gate_up_proj, &mut self.down_proj] + { + let gate_up_proj = self.gate_up_proj.clone().convert_to_isq().unwrap(); + self.gate_up_proj = gate_up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } + vec![ + Arc::get_mut(&mut self.gate_up_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.down_proj).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { vec![None, None] @@ -271,12 +310,12 @@ impl MlpLayer for Mlp { // gate_up, down fn new_added_delta(&self, deltas: Vec>) -> Result> { let new_gate_up = if let Some(ref delta) = deltas[0] { - merge_delta!(self.gate_up_proj, delta) + self.gate_up_proj.add_delta_w(delta)? } else { self.gate_up_proj.clone() }; let new_down = if let Some(ref delta) = deltas[1] { - merge_delta!(self.down_proj, delta) + self.down_proj.add_delta_w(delta)? 
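// Illustrative sketch, not from the patch: Phi-3 fuses the gate and up projections into a
// single quantized gate_up_proj, so the MLP above does one qmethod_matmul and splits the
// result with narrow() before the activation gating and the down projection. `phi3_mlp` is
// a hypothetical function with the same field names; dtype casts are omitted.
use candle_core::{Result, Tensor, D};
use candle_nn::Activation;
use mistralrs_quant::QuantMethod;

use crate::layers::MatMul;

fn phi3_mlp(
    xs: &Tensor,
    gate_up_proj: &dyn QuantMethod,
    down_proj: &dyn QuantMethod,
    i_size: usize, // cfg.intermediate_size
    act_fn: &Activation,
) -> Result<Tensor> {
    let up_states = MatMul.qmethod_matmul(xs, gate_up_proj)?;
    let gate = up_states.narrow(D::Minus1, 0, i_size)?;
    let up_states = up_states.narrow(D::Minus1, i_size, i_size)?;
    let up_states = (up_states * gate.apply(act_fn))?;
    MatMul.qmethod_matmul(&up_states, down_proj)
}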
} else { self.down_proj.clone() }; @@ -291,10 +330,7 @@ impl MlpLayer for Mlp { } fn dtype_device(&self) -> (DType, Device) { - match &self.gate_up_proj { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.gate_up_proj.dtype_and_device() } } @@ -389,6 +425,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -437,7 +480,7 @@ impl Model { cfg.rms_norm_eps, mapper.set_nm_device(vb_m.pp("norm"), false), )?; - let lm_head = linear_no_bias( + let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), @@ -513,8 +556,24 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((&mut layer.self_attn.qkv_proj, Some(i))); - tensors.push((&mut layer.self_attn.o_proj, Some(i))); + { + let qkv_proj = layer.self_attn.qkv_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.qkv_proj = qkv_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + if let Some(qkv) = Arc::get_mut(&mut layer.self_attn.qkv_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((qkv, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index 5194989e0..d4307aa1d 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -1,9 +1,12 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use std::sync::Arc; + +use candle_core::quantized::QTensor; use candle_core::quantized::{ggml_file, gguf_file}; -use candle_core::quantized::{QMatMul, QTensor}; use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, RotaryEmbedding}; +use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; use crate::device_map::DeviceMapper; use crate::layers::{repeat_kv, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention}; @@ -17,28 +20,26 @@ use crate::utils::progress::NiceProgressBar; use crate::DeviceMapMetadata; const MAX_SEQ_LEN: u32 = 4096; -#[derive(Debug, Clone)] struct Mlp { - feed_forward_w1: QMatMul, - feed_forward_w2: QMatMul, - feed_forward_w3: QMatMul, + feed_forward_w1: Arc, + feed_forward_w2: Arc, + feed_forward_w3: Arc, } impl Mlp { fn forward(&self, xs: &Tensor) -> Result { - let w1 = MatMul.qmatmul(xs, &self.feed_forward_w1)?; - let w3 = MatMul.qmatmul(xs, &self.feed_forward_w3)?; + let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?; + let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?; let y = &(candle_nn::ops::silu(&w1)? 
* w3)?; - MatMul.qmatmul(y, &self.feed_forward_w2) + MatMul.qmethod_matmul(y, &*self.feed_forward_w2) } } -#[derive(Debug, Clone)] enum MlpOrMoe { Mlp(Mlp), MoE { n_expert_used: usize, - feed_forward_gate_inp: QMatMul, + feed_forward_gate_inp: Arc, experts: Vec, }, } @@ -53,7 +54,7 @@ impl MlpOrMoe { } => { let (b_size, seq_len, hidden_dim) = xs.dims3()?; let xs = xs.reshape(((), hidden_dim))?; - let router_logits = feed_forward_gate_inp.forward(&xs)?; + let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?; let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; // In order to extract topk, we extract the data from the tensor and manipulate it @@ -114,10 +115,10 @@ impl MlpOrMoe { } struct LayerWeights { - attention_wq: QMatMul, - attention_wk: QMatMul, - attention_wv: QMatMul, - attention_wo: QMatMul, + attention_wq: Arc, + attention_wk: Arc, + attention_wv: Arc, + attention_wo: Arc, attention_norm: QRmsNorm, mlp_or_moe: MlpOrMoe, ffn_norm: QRmsNorm, @@ -140,9 +141,9 @@ impl LayerWeights { ) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; - let q = MatMul.qmatmul(x, &self.attention_wq)?; - let k = MatMul.qmatmul(x, &self.attention_wk)?; - let v = MatMul.qmatmul(x, &self.attention_wv)?; + let q = MatMul.qmethod_matmul(x, &*self.attention_wq)?; + let k = MatMul.qmethod_matmul(x, &*self.attention_wk)?; + let v = MatMul.qmethod_matmul(x, &*self.attention_wv)?; let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?; let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?; @@ -215,7 +216,7 @@ impl LayerWeights { y.reshape(&[b_sz, seq_len, n_embd])? }; - let y = MatMul.qmatmul(&y, &self.attention_wo)?; + let y = MatMul.qmethod_matmul(&y, &*self.attention_wo)?; Ok(y) } } @@ -224,7 +225,7 @@ pub struct ModelWeights { tok_embeddings: Embedding, layers: Vec, norm: QRmsNorm, - output: QMatMul, + output: Arc, pub device: Device, pub cache: Cache, pub max_seq_len: usize, @@ -261,18 +262,39 @@ impl ModelConfig::FromGGML for ModelWeights { let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; MlpOrMoe::Mlp(Mlp { - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w1), + b: None, + })?), + feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w2), + b: None, + })?), + feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w3), + b: None, + })?), }) }; let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq)?, - attention_wk: QMatMul::from_qtensor(attention_wk)?, - attention_wv: QMatMul::from_qtensor(attention_wv)?, - attention_wo: QMatMul::from_qtensor(attention_wo)?, + attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wq), + b: None, + })?), + attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wk), + b: None, + })?), + attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wv), + b: None, + })?), + attention_wo: 
Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wo), + b: None, + })?), attention_norm: QRmsNorm::new(attention_norm, 1e-5)?, mlp_or_moe, ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?, @@ -287,7 +309,10 @@ impl ModelConfig::FromGGML for ModelWeights { tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), layers, norm, - output: QMatMul::from_qtensor(output)?, + output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(output), + b: None, + })?), device: ct.device.clone(), cache: Cache::new(ct.hparams.n_layer as usize, false), max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml. @@ -437,9 +462,18 @@ impl ModelConfig::FromGGUF for ModelWeights { let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; MlpOrMoe::Mlp(Mlp { - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w1), + b: None, + })?), + feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w2), + b: None, + })?), + feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w3), + b: None, + })?), }) } else { let feed_forward_gate_inp = @@ -475,15 +509,24 @@ impl ModelConfig::FromGGUF for ModelWeights { .zip(dequant_ffn_down.into_iter().zip(dequant_ffn_up)) { experts.push(Mlp { - feed_forward_w1: QMatMul::from_qtensor(QTensor::quantize( - &ff_w1, gate_type, - )?)?, - feed_forward_w2: QMatMul::from_qtensor(QTensor::quantize( - &ff_w2, down_type, - )?)?, - feed_forward_w3: QMatMul::from_qtensor(QTensor::quantize( - &ff_w3, up_type, - )?)?, + feed_forward_w1: Arc::new(GgufMatMul::new( + QuantMethodConfig::Gguf { + q_weight: Arc::new(QTensor::quantize(&ff_w1, gate_type)?), + b: None, + }, + )?), + feed_forward_w2: Arc::new(GgufMatMul::new( + QuantMethodConfig::Gguf { + q_weight: Arc::new(QTensor::quantize(&ff_w2, down_type)?), + b: None, + }, + )?), + feed_forward_w3: Arc::new(GgufMatMul::new( + QuantMethodConfig::Gguf { + q_weight: Arc::new(QTensor::quantize(&ff_w3, up_type)?), + b: None, + }, + )?), }) } } @@ -502,16 +545,34 @@ impl ModelConfig::FromGGUF for ModelWeights { let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; experts.push(Mlp { - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + feed_forward_w1: Arc::new(GgufMatMul::new( + QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w1), + b: None, + }, + )?), + feed_forward_w2: Arc::new(GgufMatMul::new( + QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w2), + b: None, + }, + )?), + feed_forward_w3: Arc::new(GgufMatMul::new( + QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_w3), + b: None, + }, + )?), }) } } } MlpOrMoe::MoE { n_expert_used, - feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?, + feed_forward_gate_inp: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(feed_forward_gate_inp), + b: None, + })?), experts, } }; @@ -531,10 +592,22 @@ impl ModelConfig::FromGGUF for ModelWeights { )?), }; layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq)?, - attention_wk: 
QMatMul::from_qtensor(attention_wk)?, - attention_wv: QMatMul::from_qtensor(attention_wv)?, - attention_wo: QMatMul::from_qtensor(attention_wo)?, + attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wq), + b: None, + })?), + attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wk), + b: None, + })?), + attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wv), + b: None, + })?), + attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(attention_wo), + b: None, + })?), attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?, mlp_or_moe, ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?, @@ -549,7 +622,10 @@ impl ModelConfig::FromGGUF for ModelWeights { tok_embeddings: Embedding::new(tok_embeddings, embedding_length), layers, norm, - output: QMatMul::from_qtensor(output)?, + output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: Arc::new(output), + b: None, + })?), device: device.clone(), cache: Cache::new(block_count, false), max_seq_len, @@ -609,7 +685,7 @@ impl ModelWeights { let layer_in = layer_in.to_device(&self.device)?; let x = self.norm.forward(&layer_in)?; extract_logits( - &MatMul.qmatmul(&x.contiguous()?, &self.output)?, + &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?, context_lens, ) } diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index 827e62f8a..935980ee3 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -1,11 +1,18 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use std::sync::Arc; + use candle_core::quantized::gguf_file; +use candle_core::quantized::QMatMul; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, LayerNorm}; +use mistralrs_quant::GgufMatMul; +use mistralrs_quant::QuantMethod; +use mistralrs_quant::QuantMethodConfig; use crate::device_map::DeviceMapper; +use crate::layers::MatMul; use crate::layers::ScaledDotProductAttention; use crate::layers::{repeat_kv, CausalMasker, QLinear}; use crate::paged_attention::AttentionImplementation; @@ -19,21 +26,21 @@ use crate::DeviceMapMetadata; pub const MAX_SEQ_LEN: usize = 4096; -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { - ffn_up: QLinear, - ffn_down: QLinear, + ffn_up: Arc, + ffn_down: Arc, } impl Module for Mlp { fn forward(&self, xs: &Tensor) -> Result { - xs.apply(&self.ffn_up)?.gelu()?.apply(&self.ffn_down) + MatMul.qmethod_matmul(&MatMul.qmethod_matmul(xs, &*self.ffn_up)?, &*self.ffn_down) } } struct LayerWeights { - attn_qkv: QLinear, - attn_output: QLinear, + attn_qkv: Arc, + attn_output: Arc, attn_norm: LayerNorm, mlp: Mlp, n_head: usize, @@ -260,7 +267,22 @@ impl ModelConfig::FromGGUF for ModelWeights { let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; - let mlp = Mlp { ffn_up, ffn_down }; + let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else { + unreachable!() + }; + let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else { + unreachable!() + }; + let mlp = Mlp { + ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: ffn_up_w, + b: ffn_up.bias().cloned(), + })?), + ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: 
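// Hedged sketch (illustration only, not part of the upstream diff): every GGML/GGUF
// weight loaded above is wrapped the same way,
// Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { .. })?). A tiny constructor along
// these lines (hypothetical name `gguf_qmethod`) shows the shape of that call; the
// types come from mistralrs-quant as introduced here, and the error type is assumed to
// convert into candle_core::Error, as the surrounding `?` uses imply.
use std::sync::Arc;

use candle_core::quantized::QTensor;
use candle_core::{Result, Tensor};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

fn gguf_qmethod(q_weight: QTensor, bias: Option<Tensor>) -> Result<Arc<dyn QuantMethod>> {
    // Wrap the raw quantized tensor (and optional bias) in the GGUF-backed QuantMethod.
    let method: Arc<dyn QuantMethod> = Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
        q_weight: Arc::new(q_weight),
        b: bias,
    })?);
    Ok(method)
}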
ffn_down_w, + b: ffn_down.bias().cloned(), + })?), + }; let attn_norm = layer_norm( ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?, @@ -278,9 +300,23 @@ impl ModelConfig::FromGGUF for ModelWeights { None, )?), }; + let qkv = QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?; + let out = QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?; + let QMatMul::QTensor(qkv_w) = qkv.inner_ref().clone() else { + unreachable!() + }; + let QMatMul::QTensor(out_w) = out.inner_ref().clone() else { + unreachable!() + }; layers.push(LayerWeights { - attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, - attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, + attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: qkv_w, + b: qkv.bias().cloned(), + })?), + attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: out_w, + b: out.bias().cloned(), + })?), attn_norm, mlp, n_head: head_count, diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index 04e7eb4fe..098a6f4a5 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -1,5 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use std::sync::Arc; + use crate::device_map::DeviceMapper; use crate::layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}; use crate::layers_masker::PastKvLenCache; @@ -15,21 +17,22 @@ use candle_core::quantized::QMatMul; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::Embedding; +use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { - ffn_up: QMatMul, - ffn_down: QMatMul, + ffn_up: Arc, + ffn_down: Arc, i_size: usize, } impl Module for Mlp { fn forward(&self, xs: &Tensor) -> Result { - let up_states = MatMul.qmatmul(xs, &self.ffn_up)?; + let up_states = MatMul.qmethod_matmul(xs, &*self.ffn_up)?; let gate = up_states.narrow(D::Minus1, 0, self.i_size)?; let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?; let up_states = (up_states * gate.silu()?)?; - MatMul.qmatmul(&up_states, &self.ffn_down) + MatMul.qmethod_matmul(&up_states, &*self.ffn_down) } } @@ -40,8 +43,8 @@ fn rms_norm(w: QTensor, eps: f64) -> Result { } struct LayerWeights { - attn_qkv: QMatMul, - attn_output: QMatMul, + attn_qkv: Arc, + attn_output: Arc, attn_norm: RmsNorm, ffn_norm: RmsNorm, mlp: Mlp, @@ -79,7 +82,7 @@ impl LayerWeights { metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>, ) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; - let qkv = MatMul.qmatmul(x, &self.attn_qkv)?; + let qkv = MatMul.qmethod_matmul(x, &*self.attn_qkv)?; let query_pos = self.n_head * self.head_dim; let q = qkv.narrow(D::Minus1, 0, query_pos)?; let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?; @@ -155,7 +158,7 @@ impl LayerWeights { } else { y.reshape(&[b_sz, seq_len, n_embd])? 
}; - let y = MatMul.qmatmul(&y, &self.attn_output)?; + let y = MatMul.qmethod_matmul(&y, &*self.attn_output)?; Ok(y) } } @@ -288,9 +291,21 @@ impl ModelConfig::FromGGUF for ModelWeights { &format!("{prefix}.ffn_down.weight"), device, )?)?; + let QMatMul::QTensor(ffn_up_w) = ffn_up else { + unreachable!() + }; + let QMatMul::QTensor(ffn_down_w) = ffn_down else { + unreachable!() + }; let mlp = Mlp { - ffn_up, - ffn_down, + ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: ffn_up_w, + b: None, + })?), + ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: ffn_down_w, + b: None, + })?), i_size, }; let attn_norm = rms_norm( @@ -313,17 +328,31 @@ impl ModelConfig::FromGGUF for ModelWeights { None, )?), }; + let qkv = QMatMul::from_qtensor(ct.tensor( + reader, + &format!("{prefix}.attn_qkv.weight"), + device, + )?)?; + let out = QMatMul::from_qtensor(ct.tensor( + reader, + &format!("{prefix}.attn_output.weight"), + device, + )?)?; + let QMatMul::QTensor(qkv_w) = qkv.clone() else { + unreachable!() + }; + let QMatMul::QTensor(out_w) = out.clone() else { + unreachable!() + }; layers.push(LayerWeights { - attn_qkv: QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.attn_qkv.weight"), - device, - )?)?, - attn_output: QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.attn_output.weight"), - device, - )?)?, + attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: qkv_w, + b: None, + })?), + attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: out_w, + b: None, + })?), attn_norm, ffn_norm, mlp, diff --git a/mistralrs-core/src/models/quantized_starcoder2.rs b/mistralrs-core/src/models/quantized_starcoder2.rs index 3f34f1368..7970285ce 100644 --- a/mistralrs-core/src/models/quantized_starcoder2.rs +++ b/mistralrs-core/src/models/quantized_starcoder2.rs @@ -1,5 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use std::sync::Arc; + use crate::device_map::DeviceMapper; use crate::layers::{ repeat_kv, CausalMasker, MatMul, QLinear, RotaryEmbedding, ScaledDotProductAttention, @@ -17,18 +19,22 @@ use candle_core::quantized::QMatMul; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn::{Embedding, LayerNorm}; +use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { - ffn_up: QLinear, - ffn_down: QLinear, + ffn_up: Arc, + ffn_down: Arc, } impl Module for Mlp { fn forward(&self, xs: &Tensor) -> Result { - xs.apply(&self.ffn_up)? - .apply(&candle_nn::Activation::GeluPytorchTanh)? - .apply(&self.ffn_down) + MatMul.qmethod_matmul( + &MatMul + .qmethod_matmul(xs, &*self.ffn_up)? 
+ .apply(&candle_nn::Activation::GeluPytorchTanh)?, + &*self.ffn_down, + ) } } @@ -40,10 +46,10 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result { } struct LayerWeights { - attn_q: QLinear, - attn_k: QLinear, - attn_v: QLinear, - attn_output: QLinear, + attn_q: Arc, + attn_k: Arc, + attn_v: Arc, + attn_output: Arc, attn_norm: LayerNorm, ffn_norm: LayerNorm, mlp: Mlp, @@ -66,9 +72,9 @@ impl LayerWeights { ) -> Result { let (b_sz, q_len, hidden_size) = x.dims3()?; - let q = self.attn_q.forward(x)?; - let k = self.attn_k.forward(x)?; - let v = self.attn_v.forward(x)?; + let q = MatMul.qmethod_matmul(x, &*self.attn_q)?; + let k = MatMul.qmethod_matmul(x, &*self.attn_k)?; + let v = MatMul.qmethod_matmul(x, &*self.attn_v)?; let mut q = q.reshape((b_sz * q_len, self.n_head, self.head_dim))?; let mut k = k.reshape((b_sz * q_len, self.n_kv_head, self.head_dim))?; @@ -141,7 +147,7 @@ impl LayerWeights { y.reshape(&[b_sz, q_len, hidden_size])? }; - self.attn_output.forward(&y) + MatMul.qmethod_matmul(&y, &*self.attn_output) } } @@ -250,7 +256,22 @@ impl ModelConfig::FromGGUF for ModelWeights { let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; - let mlp = Mlp { ffn_up, ffn_down }; + let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else { + unreachable!() + }; + let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else { + unreachable!() + }; + let mlp = Mlp { + ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: ffn_up_w, + b: ffn_up.bias().cloned(), + })?), + ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: ffn_down_w, + b: ffn_down.bias().cloned(), + })?), + }; let attn_norm = layer_norm( ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?, @@ -277,11 +298,35 @@ impl ModelConfig::FromGGUF for ModelWeights { None, )?), }; + let QMatMul::QTensor(q_w) = attn_q.inner_ref().clone() else { + unreachable!() + }; + let QMatMul::QTensor(k_w) = attn_k.inner_ref().clone() else { + unreachable!() + }; + let QMatMul::QTensor(v_w) = attn_v.inner_ref().clone() else { + unreachable!() + }; + let QMatMul::QTensor(o_w) = attn_output.inner_ref().clone() else { + unreachable!() + }; layers.push(LayerWeights { - attn_q, - attn_k, - attn_v, - attn_output, + attn_q: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: q_w, + b: attn_q.bias().cloned(), + })?), + attn_k: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: k_w, + b: attn_k.bias().cloned(), + })?), + attn_v: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: v_w, + b: attn_v.bias().cloned(), + })?), + attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf { + q_weight: o_w, + b: attn_output.bias().cloned(), + })?), attn_norm, ffn_norm, mlp, diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index 0970a2417..982747a86 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -1,7 +1,8 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; -use candle_nn::{linear, linear_no_bias, Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use std::sync::Arc; use crate::{ @@ -11,9 +12,8 @@ use crate::{ }, 
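// Hedged sketch (illustration only, not part of the upstream diff): quantized_phi2 and
// quantized_starcoder2 above pull the raw QTensor back out of a QLinear with
// `let QMatMul::QTensor(..) = .. else { unreachable!() }` before re-wrapping it as a
// GgufMatMul together with the original bias; the `unreachable!()` encodes the
// assumption that a QLinear built from GGUF weights always holds the QTensor variant.
// A fallible variant of the same step could look like this hypothetical helper,
// using the candle/mistralrs-quant types as they appear in the patch.
use std::sync::Arc;

use candle_core::quantized::QMatMul;
use candle_core::Result;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::layers::QLinear;

fn qlinear_to_qmethod(linear: &QLinear) -> Result<Arc<dyn QuantMethod>> {
    // Extract the quantized weight; bail instead of panicking if the layer was
    // unexpectedly dequantized.
    let QMatMul::QTensor(q_weight) = linear.inner_ref().clone() else {
        candle_core::bail!("expected a quantized QLinear (QMatMul::QTensor)");
    };
    let method: Arc<dyn QuantMethod> = Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
        q_weight,
        b: linear.bias().cloned(),
    })?);
    Ok(method)
}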
device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{repeat_kv, CausalMasker, MatMul, QLinear, RmsNorm, ScaledDotProductAttention}, + layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, layers_masker::PastKvLenCache, - merge_delta, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, @@ -22,7 +22,7 @@ use crate::{ utils::progress::NiceProgressBar, }; -#[derive(Debug, Clone, PartialEq, serde::Deserialize, Default)] +#[derive(Debug, Clone, serde::Deserialize, Default)] pub struct Config { pub vocab_size: usize, pub hidden_size: usize, @@ -32,21 +32,19 @@ pub struct Config { pub num_key_value_heads: usize, pub max_position_embeddings: usize, pub sliding_window: usize, - pub max_window_layers: usize, - pub tie_word_embeddings: bool, pub rope_theta: f64, pub rms_norm_eps: f64, - pub use_sliding_window: bool, pub hidden_act: Activation, pub use_flash_attn: bool, + pub quantization_config: Option, } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { - gate_proj: QMatMul, - up_proj: QMatMul, - down_proj: QMatMul, + gate_proj: Arc, + up_proj: Arc, + down_proj: Arc, act_fn: Activation, params: Vec, } @@ -55,13 +53,28 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + intermediate_sz, + &cfg.quantization_config, + vb.pp("gate_proj"), + )?; + let up_proj = mistralrs_quant::linear_no_bias( + hidden_sz, + intermediate_sz, + &cfg.quantization_config, + vb.pp("up_proj"), + )?; + let down_proj = mistralrs_quant::linear_no_bias( + intermediate_sz, + hidden_sz, + &cfg.quantization_config, + vb.pp("down_proj"), + )?; Ok(Self { - gate_proj: QMatMul::Tensor(gate_proj.weight().clone()), - up_proj: QMatMul::Tensor(up_proj.weight().clone()), - down_proj: QMatMul::Tensor(down_proj.weight().clone()), + gate_proj, + up_proj, + down_proj, act_fn: cfg.hidden_act, params: vec![hidden_sz, intermediate_sz], }) @@ -74,22 +87,51 @@ impl MlpLayer for MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if matches!(self.gate_proj, QMatMul::QTensor(_)) { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let lhs = MatMul.qmatmul(&xs, &self.gate_proj)?.apply(&self.act_fn)?; - let rhs = MatMul.qmatmul(&xs, &self.up_proj)?; - let mut res = MatMul.qmatmul(&(lhs * rhs)?, &self.down_proj)?; - if matches!(self.gate_proj, QMatMul::QTensor(_)) { + let lhs = MatMul + .qmethod_matmul(&xs, &*self.gate_proj)? 
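// Hedged illustration (not part of the upstream diff): the qwen2 MLP above now builds
// its projections through mistralrs_quant::linear_no_bias, which takes the model's
// quantization_config and hands back an Arc<dyn QuantMethod>; with None this
// presumably falls back to a plain unquantized matmul, while a GPTQ config loads the
// pre-quantized weights instead. A minimal standalone use of that constructor,
// assuming only the signature visible in this patch:
use std::sync::Arc;

use candle_core::Result;
use candle_nn::VarBuilder;
use mistralrs_quant::{QuantMethod, QuantizedConfig};

fn make_gate_proj(
    hidden_size: usize,
    intermediate_size: usize,
    quantization_config: &Option<QuantizedConfig>,
    vb: VarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    // `vb.pp("gate_proj")` scopes the VarBuilder exactly as the model code above does.
    Ok(mistralrs_quant::linear_no_bias(
        hidden_size,
        intermediate_size,
        quantization_config,
        vb.pp("gate_proj"),
    )?)
}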
+ .apply(&self.act_fn)?; + let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?; + let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?; + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { - vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj] + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } + vec![ + Arc::get_mut(&mut self.gate_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.up_proj).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.down_proj).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { - vec![None, None, None] + { + let gate_proj = self.gate_proj.clone().convert_to_isq().unwrap(); + self.gate_proj = gate_proj; + let up_proj = self.up_proj.clone().convert_to_isq().unwrap(); + self.up_proj = up_proj; + let down_proj = self.down_proj.clone().convert_to_isq().unwrap(); + self.down_proj = down_proj; + } + vec![ + Arc::get_mut(&mut self.gate_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.up_proj).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.down_proj).unwrap().get_bias_mut(), + ] } fn clone(&self) -> Box { Box::new(Clone::clone(self)) @@ -99,44 +141,41 @@ impl MlpLayer for MLP { } // gate, up, down fn new_added_delta(&self, deltas: Vec>) -> Result> { - let new_gate = if let Some(ref delta) = deltas[0] { - merge_delta!(self.gate_proj, delta) + let gate_proj = if let Some(ref delta) = deltas[0] { + self.gate_proj.add_delta_w(delta)? } else { self.gate_proj.clone() }; - let new_up = if let Some(ref delta) = deltas[1] { - merge_delta!(self.up_proj, delta) + let up_proj = if let Some(ref delta) = deltas[1] { + self.up_proj.add_delta_w(delta)? } else { self.up_proj.clone() }; - let new_down = if let Some(ref delta) = deltas[2] { - merge_delta!(self.down_proj, delta) + let down_proj = if let Some(ref delta) = deltas[2] { + self.down_proj.add_delta_w(delta)? 
} else { self.down_proj.clone() }; Ok(Box::new(Self { - gate_proj: new_gate, - up_proj: new_up, - down_proj: new_down, + gate_proj, + up_proj, + down_proj, act_fn: self.act_fn, params: self.params.clone(), })) } fn dtype_device(&self) -> (DType, Device) { - match &self.gate_proj { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.gate_proj.dtype_and_device() } } struct Attention { - q_proj: QLinear, - k_proj: QLinear, - v_proj: QLinear, - o_proj: QMatMul, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, @@ -158,15 +197,35 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear( + hidden_sz, + num_heads * head_dim, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_no_bias( + num_heads * head_dim, + hidden_sz, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QLinear::from_linear(q_proj), - k_proj: QLinear::from_linear(k_proj), - v_proj: QLinear::from_linear(v_proj), - o_proj: QMatMul::Tensor(o_proj.weight().clone()), + q_proj, + k_proj, + v_proj, + o_proj, num_heads, num_kv_heads, num_kv_groups, @@ -190,13 +249,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = self.q_proj.forward(&xs)?; - let mut k = self.k_proj.forward(&xs)?; - let mut v = self.v_proj.forward(&xs)?; - if self.q_proj.is_quant() { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -267,16 +326,16 @@ impl Attention { } }; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if attention_mask.is_some() { attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? } else { attn_output.reshape((b_sz, q_len, ()))? 
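// Hedged sketch (illustration only, not part of the upstream diff): the forward paths
// above all follow the same dtype round-trip. If the quant method reports a required
// activation dtype via quantized_act_type(), cast the input to it, run the matmul
// through MatMul.qmethod_matmul, and cast the result back to the caller's dtype.
// Written out in isolation, with the types as used in this patch:
use candle_core::{Result, Tensor};
use mistralrs_quant::QuantMethod;

use crate::layers::MatMul;

fn quant_matmul(xs: &Tensor, proj: &dyn QuantMethod) -> Result<Tensor> {
    let original_dtype = xs.dtype();
    let mut xs = xs.clone();
    // Cast only when the quant method asks for a specific activation dtype (the GGUF
    // path works in f32, for example).
    if let Some(t) = proj.quantized_act_type() {
        xs = xs.to_dtype(t)?;
    }
    let mut res = MatMul.qmethod_matmul(&xs, proj)?;
    // Undo the cast so callers keep their original activation dtype.
    if proj.quantized_act_type().is_some() {
        res = res.to_dtype(original_dtype)?;
    }
    Ok(res)
}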
}; - let mut res = MatMul.qmatmul(&attn_output, &self.o_proj)?; - if self.q_proj.is_quant() { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -374,6 +433,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -430,7 +496,7 @@ impl Model { cfg.rms_norm_eps, mapper.set_nm_device(vb_m.pp("norm"), false), )?; - let lm_head = linear_no_bias( + let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), @@ -505,10 +571,40 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.inner(), Some(i))); - tensors.push((layer.self_attn.k_proj.inner(), Some(i))); - tensors.push((layer.self_attn.v_proj.inner(), Some(i))); - tensors.push((&mut layer.self_attn.o_proj, Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp @@ -523,9 +619,32 @@ impl IsqModel for Model { fn get_biases(&mut self) -> (Vec<(Option<&mut Tensor>, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.k_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.v_proj.bias_mut(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + } + tensors.push(( + Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); tensors.extend( layer .mlp diff --git a/mistralrs-core/src/models/starcoder2.rs 
b/mistralrs-core/src/models/starcoder2.rs index 2626f4b76..db5c014b0 100644 --- a/mistralrs-core/src/models/starcoder2.rs +++ b/mistralrs-core/src/models/starcoder2.rs @@ -1,17 +1,17 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; -use candle_nn::{layer_norm, linear_b, LayerNorm, VarBuilder}; +use candle_nn::{layer_norm, LayerNorm, VarBuilder}; +use mistralrs_quant::{QuantMethod, QuantizedConfig}; use std::sync::Arc; use crate::{ amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp}, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, QLinear, RotaryEmbedding, ScaledDotProductAttention}, + layers::{CausalMasker, MatMul, RotaryEmbedding, ScaledDotProductAttention}, layers_masker::PastKvLenCache, layers_utils::repeat_kv, - merge_delta, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, @@ -36,13 +36,14 @@ pub struct Config { pub(crate) use_bias: bool, pub(crate) sliding_window: Option, pub(crate) use_flash_attn: bool, + pub(crate) quantization_config: Option, } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { - c_fc: QLinear, - c_proj: QLinear, + c_fc: Arc, + c_proj: Arc, act: candle_nn::Activation, params: Vec, } @@ -50,11 +51,23 @@ struct MLP { impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size); - let c_fc = linear_b(h_size, i_size, cfg.use_bias, vb.pp("c_fc"))?; - let c_proj = linear_b(i_size, h_size, cfg.use_bias, vb.pp("c_proj"))?; + let c_fc = mistralrs_quant::linear_b( + h_size, + i_size, + cfg.use_bias, + &cfg.quantization_config, + vb.pp("c_fc"), + )?; + let c_proj = mistralrs_quant::linear_b( + i_size, + h_size, + cfg.use_bias, + &cfg.quantization_config, + vb.pp("c_proj"), + )?; Ok(Self { - c_fc: QLinear::from_linear(c_fc), - c_proj: QLinear::from_linear(c_proj), + c_fc, + c_proj, act: cfg.hidden_act, params: vec![h_size, i_size], }) @@ -67,23 +80,44 @@ impl MlpLayer for MLP { fn forward(&self, xs: &Tensor) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.c_fc.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.c_fc.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut res = xs - .apply(&self.c_fc)? - .apply(&self.act)? 
- .apply(&self.c_proj)?; - if self.c_fc.is_quant() { + let mut res = MatMul.qmethod_matmul( + &MatMul.qmethod_matmul(&xs, &*self.c_fc)?.apply(&self.act)?, + &*self.c_proj, + )?; + if self.c_fc.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } fn get_isq_tensors(&mut self) -> Vec<&mut QMatMul> { - vec![self.c_fc.inner(), self.c_proj.inner()] + { + let c_fc = self.c_fc.clone().convert_to_isq().unwrap(); + self.c_fc = c_fc; + let c_proj = self.c_proj.clone().convert_to_isq().unwrap(); + self.c_proj = c_proj; + } + vec![ + Arc::get_mut(&mut self.c_fc).unwrap().get_qmatmul(), + Arc::get_mut(&mut self.c_proj).unwrap().get_qmatmul(), + ] + .into_iter() + .flatten() + .collect::>() } fn get_isq_biases(&mut self) -> Vec> { - vec![self.c_fc.bias_mut(), self.c_proj.bias_mut()] + { + let c_fc = self.c_fc.clone().convert_to_isq().unwrap(); + self.c_fc = c_fc; + let c_proj = self.c_proj.clone().convert_to_isq().unwrap(); + self.c_proj = c_proj; + } + vec![ + Arc::get_mut(&mut self.c_fc).unwrap().get_bias_mut(), + Arc::get_mut(&mut self.c_proj).unwrap().get_bias_mut(), + ] } fn clone(&self) -> Box { Box::new(Clone::clone(self)) @@ -94,37 +128,34 @@ impl MlpLayer for MLP { // c_fc, c_proj fn new_added_delta(&self, deltas: Vec>) -> Result> { let new_c_fc = if let Some(ref delta) = deltas[0] { - merge_delta!(self.c_fc.inner_ref(), delta) + self.c_fc.add_delta_w(delta)? } else { - self.c_fc.inner_ref().clone() + self.c_fc.clone() }; let new_c_proj = if let Some(ref delta) = deltas[1] { - merge_delta!(self.c_proj.inner_ref(), delta) + self.c_proj.add_delta_w(delta)? } else { - self.c_proj.inner_ref().clone() + self.c_proj.clone() }; Ok(Box::new(Self { - c_fc: QLinear::from_old_and_qmatmul(new_c_fc, &self.c_fc), - c_proj: QLinear::from_old_and_qmatmul(new_c_proj, &self.c_proj), + c_fc: new_c_fc, + c_proj: new_c_proj, act: self.act, params: self.params.clone(), })) } fn dtype_device(&self) -> (DType, Device) { - match &self.c_fc.inner_ref() { - QMatMul::QTensor(q) => (DType::F32, q.device()), - QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), - } + self.c_fc.dtype_and_device() } } struct Attention { - q_proj: QLinear, - k_proj: QLinear, - v_proj: QLinear, - o_proj: QLinear, + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, @@ -148,15 +179,39 @@ impl Attention { let num_kv_groups = num_heads / num_kv_heads; let head_dim = hidden_sz / num_heads; let b = cfg.use_bias; - let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?; - let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?; - let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?; - let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?; + let q_proj = mistralrs_quant::linear_b( + hidden_sz, + num_heads * head_dim, + b, + &cfg.quantization_config, + vb.pp("q_proj"), + )?; + let k_proj = mistralrs_quant::linear_b( + hidden_sz, + num_kv_heads * head_dim, + b, + &cfg.quantization_config, + vb.pp("k_proj"), + )?; + let v_proj = mistralrs_quant::linear_b( + hidden_sz, + num_kv_heads * head_dim, + b, + &cfg.quantization_config, + vb.pp("v_proj"), + )?; + let o_proj = mistralrs_quant::linear_b( + num_heads * head_dim, + hidden_sz, + b, + &cfg.quantization_config, + vb.pp("o_proj"), + )?; Ok(Self { - q_proj: QLinear::from_linear(q_proj), - k_proj: QLinear::from_linear(k_proj), - v_proj: QLinear::from_linear(v_proj), - o_proj: 
QLinear::from_linear(o_proj), + q_proj, + k_proj, + v_proj, + o_proj, num_heads, num_kv_heads, num_kv_groups, @@ -181,13 +236,13 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } - let mut q = self.q_proj.forward(&xs)?; - let mut k = self.k_proj.forward(&xs)?; - let mut v = self.v_proj.forward(&xs)?; - if self.q_proj.is_quant() { + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -265,16 +320,16 @@ impl Attention { } }; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } attn_output = if attention_mask.is_some() { attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? } else { attn_output.reshape((b_sz, q_len, ()))? }; - let mut res = attn_output.apply(&self.o_proj)?; - if self.q_proj.is_quant() { + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -372,6 +427,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -502,10 +564,40 @@ impl IsqModel for Model { let mut tensors = Vec::new(); tensors.push((&mut self.lm_head, None)); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.inner(), Some(i))); - tensors.push((layer.self_attn.k_proj.inner(), Some(i))); - tensors.push((layer.self_attn.v_proj.inner(), Some(i))); - tensors.push((layer.self_attn.o_proj.inner(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + if let Some(q) = Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((q, Some(i))); + } + if let Some(k) = Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((k, Some(i))); + } + if let Some(b) = Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((b, Some(i))); + } + if let Some(o) = Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_qmatmul() + { + tensors.push((o, Some(i))); + } tensors.extend( layer .mlp @@ -520,10 +612,40 @@ impl IsqModel for Model { fn get_biases(&mut self) -> (Vec<(Option<&mut Tensor>, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push((layer.self_attn.q_proj.bias_mut(), 
Some(i))); - tensors.push((layer.self_attn.k_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.v_proj.bias_mut(), Some(i))); - tensors.push((layer.self_attn.o_proj.bias_mut(), Some(i))); + { + let q_proj = layer.self_attn.q_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.q_proj = q_proj; + let k_proj = layer.self_attn.k_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.k_proj = k_proj; + let v_proj = layer.self_attn.v_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.v_proj = v_proj; + let o_proj = layer.self_attn.o_proj.clone().convert_to_isq().unwrap(); + layer.self_attn.o_proj = o_proj; + } + tensors.push(( + Arc::get_mut(&mut layer.self_attn.q_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.k_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.v_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); + tensors.push(( + Arc::get_mut(&mut layer.self_attn.o_proj) + .unwrap() + .get_bias_mut(), + Some(i), + )); tensors.extend( layer .mlp diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index 5a5ce2ed4..3d4633265 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -107,6 +107,12 @@ impl CustomOp2 for BitWise { let result = CpuStorage::I64(result); Ok((result, l1.shape().clone())) } + CpuStorage::I32(vs1) => { + let vs2 = s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::I32(result); + Ok((result, l1.shape().clone())) + } CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")), @@ -155,6 +161,12 @@ impl CustomOp2 for BitWise { let elem_count = l1.shape().elem_count(); (d_in1_ptr, d_in2_ptr, elem_count) } + DType::I32 => { + let d_in1_ptr = *s1.as_cuda_slice::()?.device_ptr() as *const c_void; + let d_in2_ptr = *s2.as_cuda_slice::()?.device_ptr() as *const c_void; + let elem_count = l1.shape().elem_count(); + (d_in1_ptr, d_in2_ptr, elem_count) + } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")); } @@ -250,6 +262,33 @@ impl CustomOp2 for BitWise { }; CudaStorage::wrap_cuda_slice(d_out, dev) } + DType::I32 => { + let d_out = unsafe { dev.alloc::(elem_count) }.w()?; + let d_out_ptr = *d_out.device_ptr() as *mut c_void; + unsafe { + match self.op { + BitWiseOpEnum::And => ffi::bitwise_and_i32( + d_in1_ptr, + d_in2_ptr, + d_out_ptr, + u32::try_from(elem_count)?, + ), + BitWiseOpEnum::Or => ffi::bitwise_or_i32( + d_in1_ptr, + d_in2_ptr, + d_out_ptr, + u32::try_from(elem_count)?, + ), + BitWiseOpEnum::Xor => ffi::bitwise_xor_i32( + d_in1_ptr, + d_in2_ptr, + d_out_ptr, + u32::try_from(elem_count)?, + ), + } + }; + CudaStorage::wrap_cuda_slice(d_out, dev) + } _ => unreachable!(), }; Ok((dst, l1.shape().clone())) @@ -340,6 +379,7 @@ fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) -> candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n), candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n), candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n), + candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n), candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n), candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n), @@ -367,6 +407,9 @@ fn nonzero_cuda( 
candle_core::DType::I64 => { ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) } + candle_core::DType::I32 => { + ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) + } candle_core::DType::BF16 => { ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out) } @@ -395,6 +438,7 @@ impl CustomOp1 for NonZero { let result = match storage { candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout), + candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout), @@ -420,6 +464,7 @@ impl CustomOp1 for NonZero { let d_in = match storage.dtype() { candle_core::DType::U8 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::U32 => *storage.as_cuda_slice::()?.device_ptr(), + candle_core::DType::I32 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::I64 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::BF16 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::F16 => *storage.as_cuda_slice::()?.device_ptr(), diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 49ed54047..9918f8ad5 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -408,7 +408,7 @@ macro_rules! lora_model_loader { .iter() .map(|(_, x)| (*x).to_owned()) .collect::>(), - $dtype, + Some($dtype), $device, $silent, |_| true, diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 3d7031fb3..183eec1b3 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -296,6 +296,8 @@ pub enum QuantizationKind { Ggml, /// GGUF Gguf, + /// GPTQ + Gptq, } #[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 728d0de02..d1a417af0 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -158,7 +158,7 @@ impl NormalLoaderBuilder { self.with_adapter(lora_model_id, lora_order, false, None) } - pub fn build(self, loader: NormalLoaderType) -> Box { + pub fn build(self, loader: NormalLoaderType) -> anyhow::Result> { let loader: Box = match loader { NormalLoaderType::Mistral => Box::new(MistralLoader), NormalLoaderType::Gemma => Box::new(GemmaLoader), @@ -170,7 +170,7 @@ impl NormalLoaderBuilder { NormalLoaderType::Gemma2 => Box::new(Gemma2Loader), NormalLoaderType::Starcoder2 => Box::new(Starcoder2Loader), }; - Box::new(NormalLoader { + Ok(Box::new(NormalLoader { inner: loader, model_id: self.model_id.unwrap(), config: self.config, @@ -181,7 +181,7 @@ impl NormalLoaderBuilder { chat_template: self.chat_template, tokenizer_json: self.tokenizer_json, tgt_non_granular_index: self.tgt_non_granular_index, - }) + })) } } @@ -266,7 +266,7 @@ impl Loader for NormalLoader { let mut model = match self.kind { ModelKind::Normal => normal_model_loader!( paths, - dtype, + Some(dtype), &load_device, config, self.inner, @@ -281,7 +281,7 @@ impl Loader for NormalLoader { adapter: AdapterKind::XLora, } => xlora_model_loader!( paths, - dtype, + Some(dtype), &load_device, config, self.inner, @@ -563,20 +563,21 @@ impl AnyMoePipelineMixin for NormalPipeline { let regex = regex.clone(); let match_regex_clone = match_regex.to_string(); let 
layers_clone = layers.clone(); - let vb = from_mmaped_safetensors(filenames, vec![], dtype, dev, silent, move |key| { - if regex.is_match(&key) { - // Idx of the last char of the layer id, +1 - // Assumes N.MLP - let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1; - let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap(); - let layer_n = key[first_layer_idx + 1..last_layer_idx] - .parse::() - .unwrap(); - layers_clone.contains(&layer_n) || layers_clone.is_empty() - } else { - false - } - })?; + let vb = + from_mmaped_safetensors(filenames, vec![], Some(dtype), dev, silent, move |key| { + if regex.is_match(&key) { + // Idx of the last char of the layer id, +1 + // Assumes N.MLP + let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1; + let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap(); + let layer_n = key[first_layer_idx + 1..last_layer_idx] + .parse::() + .unwrap(); + layers_clone.contains(&layer_n) || layers_clone.is_empty() + } else { + false + } + })?; vbs.push(vb); } @@ -609,7 +610,7 @@ impl AnyMoePipelineMixin for NormalPipeline { let vb = from_mmaped_safetensors( gate_filenames.clone(), vec![], - dtype, + Some(dtype), dev, silent, |_| true, diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs index 416716cff..2dff44958 100644 --- a/mistralrs-core/src/pipeline/normal_loaders.rs +++ b/mistralrs-core/src/pipeline/normal_loaders.rs @@ -10,6 +10,7 @@ use candle_core::Device; use candle_nn::{Activation, VarBuilder}; use either::Either; +use mistralrs_quant::QuantizedConfig; #[cfg(feature = "pyo3_macros")] use pyo3::pyclass; @@ -116,6 +117,7 @@ struct MistralBasicConfig { rope_theta: f64, sliding_window: Option, head_dim: Option, + quantization_config: Option, } impl MistralBasicConfig { @@ -135,6 +137,7 @@ impl MistralBasicConfig { sliding_window: basic_config.sliding_window, use_flash_attn, head_dim: basic_config.head_dim, + quantization_config: basic_config.quantization_config, }) } } @@ -215,6 +218,7 @@ struct GemmaBasicConfig { #[serde(default = "default_max_position_embeddings")] max_position_embeddings: usize, + quantization_config: Option, } impl GemmaBasicConfig { @@ -235,6 +239,7 @@ impl GemmaBasicConfig { attention_bias: basic_config.attention_bias, head_dim: basic_config.head_dim, use_flash_attn, + quantization_config: basic_config.quantization_config, }) } } @@ -309,6 +314,7 @@ struct LlamaBasicConfig { rope_theta: f32, max_position_embeddings: usize, rope_scaling: Option, + quantization_config: Option, } fn default_rope() -> f32 { @@ -332,6 +338,7 @@ impl LlamaBasicConfig { use_flash_attn, max_position_embeddings: basic_config.max_position_embeddings, rope_scaling: basic_config.rope_scaling, + quantization_config: basic_config.quantization_config, }) } } @@ -408,6 +415,7 @@ struct MixtralBasicConfig { sliding_window: Option, num_experts_per_tok: usize, num_local_experts: usize, + quantization_config: Option, } impl MixtralBasicConfig { @@ -428,6 +436,7 @@ impl MixtralBasicConfig { use_flash_attn, num_experts_per_tok: basic_config.num_experts_per_tok, num_local_experts: basic_config.num_local_experts, + quantization_config: basic_config.quantization_config, }) } } @@ -497,10 +506,10 @@ struct Phi2BasicConfig { hidden_act: Activation, max_position_embeddings: usize, layer_norm_eps: f64, - tie_word_embeddings: bool, rope_theta: f32, partial_rotary_factor: f64, qk_layernorm: bool, + quantization_config: Option, } impl Phi2BasicConfig { @@ -517,10 +526,10 @@ impl Phi2BasicConfig { 
max_position_embeddings: basic_config.max_position_embeddings, rope_theta: basic_config.rope_theta, layer_norm_eps: basic_config.layer_norm_eps, - tie_word_embeddings: basic_config.tie_word_embeddings, partial_rotary_factor: basic_config.partial_rotary_factor, qk_layernorm: basic_config.qk_layernorm, use_flash_attn, + quantization_config: basic_config.quantization_config, }) } } @@ -602,6 +611,7 @@ struct Phi3BasicConfig { max_position_embeddings: usize, original_max_position_embeddings: usize, sliding_window: Option, + quantization_config: Option, } impl Phi3BasicConfig { @@ -624,6 +634,7 @@ impl Phi3BasicConfig { original_max_position_embeddings: basic_config.original_max_position_embeddings, use_flash_attn, sliding_window: basic_config.sliding_window, + quantization_config: basic_config.quantization_config, }) } } @@ -695,12 +706,10 @@ struct Qwen2BasicConfig { num_key_value_heads: usize, max_position_embeddings: usize, sliding_window: usize, - max_window_layers: usize, - tie_word_embeddings: bool, rope_theta: f64, rms_norm_eps: f64, - use_sliding_window: bool, hidden_act: Activation, + quantization_config: Option, } impl Qwen2BasicConfig { @@ -718,10 +727,8 @@ impl Qwen2BasicConfig { rope_theta: basic_config.rope_theta, rms_norm_eps: basic_config.rms_norm_eps, sliding_window: basic_config.sliding_window, - max_window_layers: basic_config.max_window_layers, - tie_word_embeddings: basic_config.tie_word_embeddings, - use_sliding_window: basic_config.use_sliding_window, use_flash_attn, + quantization_config: basic_config.quantization_config, }) } } @@ -851,6 +858,7 @@ struct Starcoder2BasicConfig { rope_theta: f64, use_bias: bool, sliding_window: Option, + quantization_config: Option, } impl Starcoder2BasicConfig { @@ -870,6 +878,7 @@ impl Starcoder2BasicConfig { use_flash_attn, norm_epsilon: basic_config.norm_epsilon, use_bias: basic_config.use_bias, + quantization_config: basic_config.quantization_config, }) } } diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index 990a87819..a7ca84730 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -24,8 +24,9 @@ use crate::{ }; // Match files against these, avoids situations like `consolidated.safetensors` -const SAFETENSOR_MATCH: &str = r"model-\d{5}-of-\d{5}"; -const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}"; +const SAFETENSOR_MATCH: &str = r"model-\d{5}-of-\d{5}.safetensors\b"; +const QUANT_SAFETENSOR_MATCH: &str = r"model.safetensors\b"; +const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}.((pth)|(pt)|(bin))\b"; pub(crate) struct XLoraPaths { pub adapter_configs: Option>, @@ -285,11 +286,15 @@ pub fn get_model_paths( None => { // We only match these patterns for model names let safetensor_match = Regex::new(SAFETENSOR_MATCH)?; + let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?; let pickle_match = Regex::new(PICKLE_MATCH)?; let mut filenames = vec![]; - let listing = api_dir_list!(api, model_id) - .filter(|x| safetensor_match.is_match(x) || pickle_match.is_match(x)); + let listing = api_dir_list!(api, model_id).filter(|x| { + safetensor_match.is_match(x) + || pickle_match.is_match(x) + || quant_safetensor_match.is_match(x) + }); let safetensors = listing .clone() .filter(|x| x.ends_with(".safetensors")) diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index 1bbd04bf0..61e34b06b 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -197,7 +197,7 
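// Illustration only (not from the patch): the tightened SAFETENSOR_MATCH above now
// requires the file extension, and the new QUANT_SAFETENSOR_MATCH admits the
// single-file `model.safetensors` layout that pre-quantized (e.g. GPTQ) repos commonly
// ship, while files such as `consolidated.safetensors` stay excluded. A quick check of
// what the two patterns accept, using the same `regex` crate as paths.rs:
#[cfg(test)]
mod safetensor_pattern_checks {
    use regex::Regex;

    #[test]
    fn matches_expected_filenames() {
        let sharded = Regex::new(r"model-\d{5}-of-\d{5}.safetensors\b").unwrap();
        let single = Regex::new(r"model.safetensors\b").unwrap();

        // Sharded checkpoints keep matching as before.
        assert!(sharded.is_match("model-00001-of-00002.safetensors"));
        // Single-file (often GPTQ-quantized) checkpoints are now picked up too.
        assert!(single.is_match("model.safetensors"));
        // Unrelated files such as `consolidated.safetensors` still do not match.
        assert!(!sharded.is_match("consolidated.safetensors"));
        assert!(!single.is_match("consolidated.safetensors"));
    }
}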
@@ impl Loader for VisionLoader { let mut model = match self.kind { ModelKind::Normal => vision_normal_model_loader!( paths, - dtype, + Some(dtype), &load_device, config, self.inner, @@ -461,20 +461,21 @@ impl AnyMoePipelineMixin for VisionPipeline { let regex = regex.clone(); let match_regex_clone = match_regex.to_string(); let layers_clone = layers.clone(); - let vb = from_mmaped_safetensors(filenames, vec![], dtype, dev, silent, move |key| { - if regex.is_match(&key) { - // Idx of the last char of the layer id, +1 - // Assumes N.MLP - let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1; - let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap(); - let layer_n = key[first_layer_idx + 1..last_layer_idx] - .parse::() - .unwrap(); - layers_clone.contains(&layer_n) || layers_clone.is_empty() - } else { - false - } - })?; + let vb = + from_mmaped_safetensors(filenames, vec![], Some(dtype), dev, silent, move |key| { + if regex.is_match(&key) { + // Idx of the last char of the layer id, +1 + // Assumes N.MLP + let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1; + let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap(); + let layer_n = key[first_layer_idx + 1..last_layer_idx] + .parse::() + .unwrap(); + layers_clone.contains(&layer_n) || layers_clone.is_empty() + } else { + false + } + })?; vbs.push(vb); } @@ -507,7 +508,7 @@ impl AnyMoePipelineMixin for VisionPipeline { let vb = from_mmaped_safetensors( gate_filenames.clone(), vec![], - dtype, + Some(dtype), dev, silent, |_| true, diff --git a/mistralrs-core/src/toml_selector.rs b/mistralrs-core/src/toml_selector.rs index 88d2d9571..d6b547aef 100644 --- a/mistralrs-core/src/toml_selector.rs +++ b/mistralrs-core/src/toml_selector.rs @@ -319,7 +319,7 @@ fn loader_from_selected( args.tokenizer_json, Some(model_id), ) - .build(arch), + .build(arch)?, TomlModelSelected::XLora { model_id, xlora_model_id, @@ -345,7 +345,7 @@ fn loader_from_selected( args.no_kv_cache, tgt_non_granular_index, ) - .build(arch), + .build(arch)?, TomlModelSelected::Lora { model_id, adapters_model_id, @@ -368,7 +368,7 @@ fn loader_from_selected( .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")), )?, ) - .build(arch), + .build(arch)?, TomlModelSelected::GGUF { tok_model_id, quantized_model_id, diff --git a/mistralrs-core/src/utils/model_config.rs b/mistralrs-core/src/utils/model_config.rs index 6a941a25b..c26e4e553 100644 --- a/mistralrs-core/src/utils/model_config.rs +++ b/mistralrs-core/src/utils/model_config.rs @@ -78,7 +78,7 @@ impl<'a> Adapter<'a> { .iter() .map(|(_, x)| (*x).to_owned()) .collect::>(), - candle_core::DType::F32, + Some(candle_core::DType::F32), device, silent, |_| true, diff --git a/mistralrs-core/src/utils/varbuilder_utils.rs b/mistralrs-core/src/utils/varbuilder_utils.rs index a283789f6..ee2edb109 100644 --- a/mistralrs-core/src/utils/varbuilder_utils.rs +++ b/mistralrs-core/src/utils/varbuilder_utils.rs @@ -19,7 +19,7 @@ use super::progress::{Joinable, NonThreadingHandle, Parellelize}; trait TensorLoaderBackend { fn get_names(&self) -> Vec; - fn load_name(&self, name: &str, device: &Device, dtype: DType) -> Result; + fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result; } struct SafetensorBackend(MmapedSafetensors); @@ -32,8 +32,17 @@ impl TensorLoaderBackend for SafetensorBackend { .map(|(name, _)| name) .collect::>() } - fn load_name(&self, name: &str, device: &Device, dtype: DType) -> Result { - self.0.load(name, device)?.to_dtype(dtype) + fn load_name(&self, name: 
&str, device: &Device, dtype: Option) -> Result { + let t = self.0.load(name, device)?; + if let Some(dtype) = dtype { + if t.dtype() == DType::I32 { + Ok(t) + } else { + t.to_dtype(dtype) + } + } else { + Ok(t) + } } } @@ -43,14 +52,23 @@ impl TensorLoaderBackend for PickleBackend { fn get_names(&self) -> Vec { self.0.tensor_infos().keys().cloned().collect::>() } - fn load_name(&self, name: &str, device: &Device, dtype: DType) -> Result { - self.0 + fn load_name(&self, name: &str, device: &Device, dtype: Option) -> Result { + let t = self + .0 .get(name)? .ok_or(candle_core::Error::Msg(format!( "Could not load tensor {name}" )))? - .to_device(device)? - .to_dtype(dtype) + .to_device(device)?; + if let Some(dtype) = dtype { + if t.dtype() == DType::I32 { + Ok(t) + } else { + t.to_dtype(dtype) + } + } else { + Ok(t) + } } } @@ -60,7 +78,7 @@ impl TensorLoaderBackend for PickleBackend { pub(crate) fn from_mmaped_safetensors<'a>( paths: Vec, xlora_paths: Vec, - dtype: DType, + dtype: Option, device: &Device, silent: bool, predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static, @@ -100,7 +118,13 @@ pub(crate) fn from_mmaped_safetensors<'a>( ws.extend(h.join().unwrap()?); } - Ok(VarBuilder::from_tensors(ws, dtype, device)) + // TODO(EricLBuehler): separation of concerns. + // This is to have WNA16 for GPTQ which is required. No bf16 for GPTQ + Ok(VarBuilder::from_tensors( + ws, + dtype.unwrap_or(DType::F16), + device, + )) } pub(crate) fn load_preload_adapters<'a>( @@ -114,7 +138,7 @@ pub(crate) fn load_preload_adapters<'a>( for (name, (path, config)) in paths { let loader = Common::new(); let loaded_tensors = - loader.load_tensors_from_path(path, device, dtype, silent, |_| true)?; + loader.load_tensors_from_path(path, device, Some(dtype), silent, |_| true)?; map.insert( name.clone(), @@ -136,7 +160,7 @@ trait LoadTensors { &self, path: &PathBuf, device: &Device, - dtype: DType, + dtype: Option, is_silent: bool, predicate: impl Fn(String) -> bool, ) -> Result> { diff --git a/mistralrs-core/src/vision_models/idefics2.rs b/mistralrs-core/src/vision_models/idefics2.rs index ff47c5b40..3311614fa 100644 --- a/mistralrs-core/src/vision_models/idefics2.rs +++ b/mistralrs-core/src/vision_models/idefics2.rs @@ -188,6 +188,7 @@ impl From for mistral::Config { sliding_window: val.sliding_window, use_flash_attn: val.use_flash_attn, head_dim: None, + quantization_config: None, } } } diff --git a/mistralrs-core/src/vision_models/llava/config.rs b/mistralrs-core/src/vision_models/llava/config.rs index 038fb985f..0a854a34d 100644 --- a/mistralrs-core/src/vision_models/llava/config.rs +++ b/mistralrs-core/src/vision_models/llava/config.rs @@ -80,6 +80,7 @@ impl Config { rope_theta: self.text_config.rope_theta, max_position_embeddings: self.text_config.max_position_embeddings, rope_scaling: self.text_config.rope_scaling.clone(), + quantization_config: None, } } @@ -98,6 +99,7 @@ impl Config { sliding_window: self.text_config.sliding_window, use_flash_attn: self.use_flash_attn, head_dim: None, + quantization_config: None, } } diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index 737fe6172..cdc170477 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -60,7 +60,7 @@ impl Module for RmsNorm { } } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Arc, @@ -134,8 +134,8 @@ impl MLP { ) -> Result { let original_dtype = xs.dtype(); let mut xs = 
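// Hedged sketch (illustration only, not part of the upstream diff): both loader
// backends above now take Option<DType> and deliberately skip the cast for I32
// tensors, since GPTQ checkpoints store qweight/qzeros/g_idx as packed int32 and
// casting them to the compute dtype would destroy the packing. The shared decision can
// be written as one small function (the name `cast_loaded_tensor` is illustrative):
use candle_core::{DType, Result, Tensor};

fn cast_loaded_tensor(t: Tensor, wanted: Option<DType>) -> Result<Tensor> {
    match wanted {
        // Keep packed integer data (GPTQ qweight/qzeros/g_idx) exactly as stored.
        Some(_) if t.dtype() == DType::I32 => Ok(t),
        // Otherwise honour the requested compute dtype, e.g. F16 or BF16.
        Some(dtype) => t.to_dtype(dtype),
        // No dtype requested: hand the tensor back untouched.
        None => Ok(t),
    }
}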
xs.clone(); - if self.gate_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let lhs = self .gate_proj @@ -158,14 +158,14 @@ impl MLP { global_scaling_weight, is_scaling_pass, )?; - if self.gate_proj.is_quant() { + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { q_proj: Arc, k_proj: Arc, @@ -273,8 +273,8 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut q = self.q_proj.lora_forward( &xs, @@ -294,7 +294,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -349,8 +349,8 @@ impl Attention { q_len, )?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } let mut res = self.o_proj.lora_forward( &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?, @@ -358,14 +358,14 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct DecoderLayer { self_attn: Attention, mlp: MLP, @@ -491,6 +491,13 @@ impl XLoraModel { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -698,8 +705,8 @@ impl XLoraModel { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -718,8 +725,8 @@ impl XLoraModel { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -738,8 +745,8 @@ impl XLoraModel { None, )? 
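Every `is_quant()` call site in these model files is rewritten to the same shape: ask the quantized layer which activation dtype its kernel needs, cast in if it names one, run the (LoRA) forward, and cast back to the original dtype afterwards. The GGUF path appears to keep requesting F32 as before, while the new GPTQ path can consume F16 activations, so the hard-coded F32 cast is gone. A condensed sketch of the pattern; `ActCast` is a stand-in for the quantized-linear trait defined in `mistralrs-quant`, and only the `quantized_act_type` contract implied by these call sites is assumed:

```rust
use candle_core::{DType, Result, Tensor};

trait ActCast {
    /// `Some(dtype)` if the kernel needs activations in a specific dtype
    /// (e.g. F32 for a GGUF-style kernel), `None` if it takes them as-is.
    fn quantized_act_type(&self) -> Option<DType>;
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

fn forward_with_act_cast(layer: &dyn ActCast, xs: &Tensor) -> Result<Tensor> {
    let original_dtype = xs.dtype();
    let xs = match layer.quantized_act_type() {
        Some(t) => xs.to_dtype(t)?, // cast in for the quantized kernel
        None => xs.clone(),
    };
    let ys = layer.forward(&xs)?;
    if layer.quantized_act_type().is_some() {
        // Cast back out so the rest of the model keeps its original dtype.
        ys.to_dtype(original_dtype)
    } else {
        Ok(ys)
    }
}
```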
.contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -752,36 +759,31 @@ impl XLoraModel { impl IsqModel for XLoraModel { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.gate_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.up_proj).unwrap().inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.gate_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.up_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-core/src/xlora_models/gemma2.rs b/mistralrs-core/src/xlora_models/gemma2.rs index 9a4ef3524..d4d82bd2f 100644 --- a/mistralrs-core/src/xlora_models/gemma2.rs +++ b/mistralrs-core/src/xlora_models/gemma2.rs @@ -24,7 +24,7 @@ use crate::{ use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] struct RmsNorm { weight: Tensor, eps: f64, @@ -54,7 +54,7 @@ impl Module for RmsNorm { } } -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Arc, @@ -128,8 +128,8 @@ impl MLP { ) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.gate_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let lhs = self .gate_proj @@ -152,14 +152,14 @@ impl MLP { global_scaling_weight, is_scaling_pass, )?; - if self.gate_proj.is_quant() { + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { q_proj: Arc, k_proj: Arc, @@ -279,8 +279,8 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut q = 
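`IsqModel::get_matmuls` previously assumed every projection wrapped a plain `QMatMul` it could hand out for in-situ quantization. With GPTQ layers in the mix, `inner()` evidently returns an `Option`, and only layers that expose an inner `QMatMul` are collected. A self-contained sketch of that collection logic; the `Layer` type below is illustrative, whereas the real code reaches the layers through `Arc::get_mut(...).unwrap()` on trait objects:

```rust
use candle_core::quantized::QMatMul;

/// Illustrative linear layer whose inner `QMatMul` may be absent
/// (a GPTQ-quantized layer has no plain `QMatMul` to re-quantize in place).
struct Layer {
    proj: Option<QMatMul>,
}

impl Layer {
    fn inner(&mut self) -> Option<&mut QMatMul> {
        self.proj.as_mut()
    }
}

/// Mirror of the rewritten `get_matmuls`: collect only the matmuls that can
/// participate in ISQ, tagged with their layer index (None for the lm_head).
fn collect_isq_targets<'a>(
    lm_head: &'a mut Layer,
    layers: &'a mut [Layer],
) -> Vec<(&'a mut QMatMul, Option<usize>)> {
    let mut tensors = Vec::new();
    if let Some(x) = lm_head.inner() {
        tensors.push((x, None));
    }
    for (i, layer) in layers.iter_mut().enumerate() {
        if let Some(x) = layer.inner() {
            tensors.push((x, Some(i)));
        }
    }
    tensors
}
```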
self.q_proj.lora_forward( &xs, @@ -300,7 +300,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -377,21 +377,23 @@ impl Attention { // Convert to contiguous as matmul doesn't support strided vs for now. let mut attn_output = MatMul.matmul(&att, &v.contiguous()?)?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } - let res = attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?; - let mut res = - self.o_proj - .lora_forward(&res, scalings, global_scaling_weight, is_scaling_pass)?; - if self.q_proj.is_quant() { + let mut res = self.o_proj.lora_forward( + &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?, + scalings.clone(), + global_scaling_weight, + is_scaling_pass, + )?; + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct DecoderLayer { self_attn: Attention, mlp: MLP, @@ -541,6 +543,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -722,8 +731,8 @@ impl Model { } let xs = xs.to_device(&self.device)?; let mut xs = xs.apply(&self.norm)?; - if self.lm_head.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut xs = self.lm_head.lora_forward(&xs, None, 1.0, None)?; @@ -775,8 +784,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -795,8 +804,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -815,8 +824,8 @@ impl Model { None, )? 
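Each model constructor now logs the quantization scheme it found in the model's `config.json`. The config shape below is a rough sketch inferred from the two fields the log line touches; in the patch `quant_method` looks like an enum that implements `Display`, simplified here to a `String`, so this is not the exact `QuantizedConfig` from `mistralrs-quant`:

```rust
use serde::Deserialize;

/// Simplified stand-in for the `quantization_config` block of an HF
/// `config.json`; only the fields the log line uses are modeled.
#[derive(Debug, Clone, Deserialize)]
pub struct QuantizedConfig {
    pub quant_method: String, // e.g. "gptq"
    pub bits: usize,          // e.g. 4 for the WNA16 kernels added here
}

pub fn log_quantization(cfg: Option<&QuantizedConfig>) {
    if let Some(quant_cfg) = cfg {
        tracing::info!(
            "Using {} quantization in {} bits.",
            quant_cfg.quant_method,
            quant_cfg.bits
        );
    }
}
```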
.contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -829,36 +838,31 @@ impl Model { impl IsqModel for Model { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.gate_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.up_proj).unwrap().inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.gate_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.up_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs index a0af0ba17..02de00a9d 100644 --- a/mistralrs-core/src/xlora_models/llama.rs +++ b/mistralrs-core/src/xlora_models/llama.rs @@ -23,7 +23,7 @@ use crate::{ use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] struct CausalSelfAttention { q_proj: Arc, k_proj: Arc, @@ -55,8 +55,8 @@ impl CausalSelfAttention { let original_dtype = x.dtype(); let mut x = x.clone(); - if self.q_proj.is_quant() { - x = x.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + x = x.to_dtype(t)?; } let mut q = self.q_proj.lora_forward( &x, @@ -76,7 +76,7 @@ impl CausalSelfAttention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -133,8 +133,8 @@ impl CausalSelfAttention { )?; let mut y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; - if self.q_proj.is_quant() { - y = y.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + y = y.to_dtype(t)?; } let mut res = self.o_proj.lora_forward( &y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?, @@ -142,7 +142,7 @@ impl CausalSelfAttention { global_scaling_weight, is_scaling_pass, )?; - if 
self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -219,7 +219,7 @@ impl CausalSelfAttention { } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { c_fc1: Arc, c_fc2: Arc, @@ -236,8 +236,8 @@ impl Mlp { ) -> Result { let original_dtype = x.dtype(); let mut x = x.clone(); - if self.c_fc1.is_quant() { - x = x.to_dtype(DType::F32)?; + if let Some(t) = self.c_fc1.quantized_act_type() { + x = x.to_dtype(t)?; } let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward( &x, @@ -256,7 +256,7 @@ impl Mlp { global_scaling_weight, is_scaling_pass, )?; - if self.c_fc1.is_quant() { + if self.c_fc1.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -314,7 +314,7 @@ impl Mlp { } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Block { rms_1: RmsNorm, attn: CausalSelfAttention, @@ -518,8 +518,8 @@ impl XLoraLlama { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -538,8 +538,8 @@ impl XLoraLlama { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -558,8 +558,8 @@ impl XLoraLlama { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -579,6 +579,13 @@ impl XLoraLlama { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -695,30 +702,31 @@ impl XLoraLlama { impl IsqModel for XLoraLlama { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.blocks.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push((Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().inner(), Some(i))); - tensors.push((Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().inner(), Some(i))); - tensors.push(( - Arc::get_mut(&mut layer.mlp.c_proj).unwrap().inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut 
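For reference, the Llama MLP that the `c_fc1`/`c_fc2`/`c_proj` hunks above belong to is a SwiGLU block: the two parallel projections are combined as `silu(c_fc1(x)) * c_fc2(x)` before the down projection. A plain, unquantized sketch of that data flow, ignoring the LoRA scaling arguments and the dtype casts handled above:

```rust
use candle_core::{Result, Tensor};
use candle_nn::{ops, Linear, Module};

/// Unquantized SwiGLU MLP with the same wiring as the Llama `Mlp` above.
struct SwiGluMlp {
    c_fc1: Linear,  // gate projection
    c_fc2: Linear,  // up projection
    c_proj: Linear, // down projection
}

impl SwiGluMlp {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let gated = (ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&gated)
    }
}
```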
layer.attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.c_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index ff7f8b9c7..9618305bf 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -26,7 +26,7 @@ use crate::{ use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker}; -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Arc, @@ -97,8 +97,8 @@ impl MLP { ) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.gate_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let lhs = self .gate_proj @@ -121,14 +121,14 @@ impl MLP { global_scaling_weight, is_scaling_pass, )?; - if self.gate_proj.is_quant() { + if self.gate_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { q_proj: Arc, k_proj: Arc, @@ -233,8 +233,8 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut q = self.q_proj.lora_forward( &xs, @@ -254,7 +254,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -316,8 +316,8 @@ impl Attention { q_len, )?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } let mut res = self.o_proj.lora_forward( &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?, @@ -325,14 +325,14 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct DecoderLayer { self_attn: Attention, mlp: MLP, @@ -458,6 +458,13 @@ impl XLoraModel { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -664,8 +671,8 @@ impl XLoraModel { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -684,8 +691,8 @@ impl XLoraModel { None, )? 
.contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -704,8 +711,8 @@ impl XLoraModel { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -718,36 +725,31 @@ impl XLoraModel { impl IsqModel for XLoraModel { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.gate_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.up_proj).unwrap().inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.gate_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.up_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index 07b2d77f0..ef7365419 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -28,7 +28,7 @@ use crate::{ use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { q_proj: Arc, k_proj: Arc, @@ -133,8 +133,8 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut q = self.q_proj.lora_forward( &xs, @@ -154,7 +154,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -216,8 +216,8 @@ impl Attention { q_len, )?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = 
self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } let mut res = self.o_proj.lora_forward( &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?, @@ -225,14 +225,14 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct BlockSparseTop2MLP { w1: Arc, w2: Arc, @@ -302,8 +302,8 @@ impl BlockSparseTop2MLP { ) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.w1.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.w1.quantized_act_type() { + xs = xs.to_dtype(t)?; } let lhs = self .w1 @@ -326,14 +326,14 @@ impl BlockSparseTop2MLP { global_scaling_weight, is_scaling_pass, )?; - if self.w1.is_quant() { + if self.w1.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct SparseMoeBlock { gate: Arc, experts: Vec, @@ -398,8 +398,8 @@ impl SparseMoeBlock { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.gate.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut router_logits = self.gate.lora_forward( &xs, @@ -407,7 +407,7 @@ impl SparseMoeBlock { global_scaling_weight, is_scaling_pass, )?; - if self.gate.is_quant() { + if self.gate.quantized_act_type().is_some() { router_logits = router_logits.to_dtype(original_dtype)?; } @@ -470,7 +470,7 @@ impl SparseMoeBlock { } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct DecoderLayer { self_attn: Attention, block_sparse_moe: SparseMoeBlock, @@ -596,6 +596,13 @@ impl XLoraModel { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -801,8 +808,8 @@ impl XLoraModel { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -821,8 +828,8 @@ impl XLoraModel { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -841,8 +848,8 @@ impl XLoraModel { None, )? 
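The `SparseMoeBlock` hunks above only touch the dtype handling around the gate, but for orientation: a block-sparse top-2 MoE scores every token against each expert with the gate, keeps the two best experts, and mixes their outputs with renormalized routing weights. A plain-`f32`, per-token sketch of that routing step; this is the standard Mixtral recipe for context, not code lifted from the patch:

```rust
/// Pick the top-2 experts for one token given its router logits and return
/// (expert index, normalized routing weight) pairs. Assumes at least two
/// experts; shown only to contextualize the `SparseMoeBlock` above.
fn top2_routing(router_logits: &[f32]) -> [(usize, f32); 2] {
    // Softmax over the experts.
    let max = router_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = router_logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();

    // Indices of the two largest probabilities.
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
    let (i0, i1) = (idx[0], idx[1]);

    // Renormalize the two selected weights so they sum to one.
    let norm = probs[i0] + probs[i1];
    [(i0, probs[i0] / norm), (i1, probs[i1] / norm)]
}
```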
.contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -855,34 +862,38 @@ impl XLoraModel { impl IsqModel for XLoraModel { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.block_sparse_moe.gate) - .unwrap() - .inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.block_sparse_moe.gate) + .unwrap() + .inner() + { + tensors.push((x, Some(i))); + } for expert in &mut layer.block_sparse_moe.experts { - tensors.push((Arc::get_mut(&mut expert.w1).unwrap().inner(), Some(i))); - tensors.push((Arc::get_mut(&mut expert.w2).unwrap().inner(), Some(i))); - tensors.push((Arc::get_mut(&mut expert.w3).unwrap().inner(), Some(i))); + if let Some(x) = Arc::get_mut(&mut expert.w1).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut expert.w2).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut expert.w3).unwrap().inner() { + tensors.push((x, Some(i))); + } } } (tensors, &*self.mapper) diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs index 07f55ff65..95e6fc75c 100644 --- a/mistralrs-core/src/xlora_models/phi2.rs +++ b/mistralrs-core/src/xlora_models/phi2.rs @@ -33,7 +33,7 @@ use crate::{ use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { fc1: Arc, @@ -92,8 +92,8 @@ impl MLP { ) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.fc1.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.fc1.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut res = self.fc2.lora_forward( &self @@ -109,7 +109,7 @@ impl MLP { global_scaling_weight, is_scaling_pass, )?; - if self.fc1.is_quant() { + if self.fc1.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -225,8 +225,8 @@ impl Attention { let (b_size, seq_len, _n_embd) = xs.dims3()?; let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut q = 
self.q_proj.lora_forward( &xs, @@ -246,7 +246,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -306,8 +306,8 @@ impl Attention { let mut attn_output = attn_output .transpose(1, 2)? .reshape((b_size, seq_len, ()))?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } let mut res = self.dense.lora_forward( &attn_output, @@ -315,7 +315,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -435,6 +435,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -632,8 +639,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -652,8 +659,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -672,8 +679,8 @@ impl Model { None, )? 
.contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -686,26 +693,28 @@ impl Model { impl IsqModel for Model { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.dense).unwrap().inner(), - Some(i), - )); - tensors.push((Arc::get_mut(&mut layer.mlp.fc1).unwrap().inner(), Some(i))); - tensors.push((Arc::get_mut(&mut layer.mlp.fc2).unwrap().inner(), Some(i))); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.dense).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.fc1).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.fc2).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 84807a672..f10fac022 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -29,7 +29,7 @@ use crate::pipeline::Cache; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { qkv_proj: Arc, o_proj: Arc, @@ -109,8 +109,8 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.qkv_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.qkv_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut qkv = self.qkv_proj.lora_forward( &xs, @@ -118,7 +118,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.qkv_proj.is_quant() { + if self.qkv_proj.quantized_act_type().is_some() { qkv = qkv.to_dtype(original_dtype)?; } let query_pos = self.num_heads * self.head_dim; @@ -176,8 +176,8 @@ impl Attention { q_len, )?; - if self.qkv_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.qkv_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } let mut res = self.o_proj.lora_forward( &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?, @@ -185,14 +185,14 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.qkv_proj.is_quant() { + if self.qkv_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Mlp { gate_up_proj: Arc, down_proj: Arc, @@ -252,8 +252,8 @@ impl Mlp { ) -> Result { let 
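Phi-3 fuses the query/key/value projections into a single `qkv_proj`, which is why only one `quantized_act_type` cast is needed before the projection. After the fused matmul, the output is sliced at `query_pos = num_heads * head_dim` (visible above) into its Q, K and V parts. A sketch of that split, assuming the usual `[batch, seq, q + kv + kv]` packing; the exact slicing code is not part of these hunks:

```rust
use candle_core::{D, Result, Tensor};

/// Split a fused QKV projection output of shape
/// [batch, seq, (num_heads + 2 * num_kv_heads) * head_dim] into Q, K, V.
fn split_qkv(
    qkv: &Tensor,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
) -> Result<(Tensor, Tensor, Tensor)> {
    let query_pos = num_heads * head_dim;
    let kv_size = num_kv_heads * head_dim;
    let q = qkv.narrow(D::Minus1, 0, query_pos)?;
    let k = qkv.narrow(D::Minus1, query_pos, kv_size)?;
    let v = qkv.narrow(D::Minus1, query_pos + kv_size, kv_size)?;
    Ok((q, k, v))
}
```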
original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.gate_up_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.gate_up_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let up_states = self.gate_up_proj.lora_forward( &xs, @@ -270,14 +270,14 @@ impl Mlp { global_scaling_weight, is_scaling_pass, )?; - if self.gate_up_proj.is_quant() { + if self.gate_up_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct DecoderLayer { self_attn: Attention, mlp: Mlp, @@ -403,6 +403,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -598,8 +605,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -618,8 +625,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -638,8 +645,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -652,24 +659,22 @@ impl Model { impl IsqModel for Model { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.qkv_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.mlp.gate_up_proj).unwrap().inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.qkv_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.gate_up_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.down_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-core/src/xlora_models/starcoder2.rs b/mistralrs-core/src/xlora_models/starcoder2.rs index 91ac480f1..8a1a59ff9 100644 --- a/mistralrs-core/src/xlora_models/starcoder2.rs +++ b/mistralrs-core/src/xlora_models/starcoder2.rs @@ -24,7 +24,7 @@ use crate::{ use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; -#[derive(Debug, Clone)] +#[derive(Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { 
c_fc: Arc, @@ -84,8 +84,8 @@ impl MLP { ) -> Result { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.c_fc.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.c_fc.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut res = self.c_proj.lora_forward( &self @@ -101,14 +101,14 @@ impl MLP { global_scaling_weight, is_scaling_pass, )?; - if self.c_fc.is_quant() { + if self.c_fc.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) } } -#[derive(Debug, Clone)] +#[derive(Clone)] struct Attention { q_proj: Arc, k_proj: Arc, @@ -220,8 +220,8 @@ impl Attention { let original_dtype = xs.dtype(); let mut xs = xs.clone(); - if self.q_proj.is_quant() { - xs = xs.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; } let mut q = self.q_proj.lora_forward( &xs, @@ -241,7 +241,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { q = q.to_dtype(original_dtype)?; k = k.to_dtype(original_dtype)?; v = v.to_dtype(original_dtype)?; @@ -303,8 +303,8 @@ impl Attention { q_len, )?; - if self.q_proj.is_quant() { - attn_output = attn_output.to_dtype(DType::F32)?; + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; } let mut res = self.o_proj.lora_forward( &attn_output @@ -314,7 +314,7 @@ impl Attention { global_scaling_weight, is_scaling_pass, )?; - if self.q_proj.is_quant() { + if self.q_proj.quantized_act_type().is_some() { res = res.to_dtype(original_dtype)?; } Ok(res) @@ -446,6 +446,13 @@ impl Model { normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization in {} bits.", + quant_cfg.quant_method.to_string(), + quant_cfg.bits + ); + } let mapper = normal_loading_metadata .mapper .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; @@ -648,8 +655,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -668,8 +675,8 @@ impl Model { None, )? .contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -688,8 +695,8 @@ impl Model { None, )? 
.contiguous()?; - if self.lm_head.is_quant() { - res = res.to_dtype(DType::F32)?; + if let Some(t) = self.lm_head.quantized_act_type() { + res = res.to_dtype(t)?; } extract_logits( &self.lm_head.lora_forward(&res, None, 1.0, None)?, @@ -702,29 +709,28 @@ impl Model { impl IsqModel for Model { fn get_matmuls(&mut self) -> (Vec<(&mut QMatMul, Option)>, &dyn DeviceMapper) { let mut tensors = Vec::new(); - tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().inner(), None)); + if let Some(x) = Arc::get_mut(&mut self.lm_head).unwrap().inner() { + tensors.push((x, None)); + } for (i, layer) in self.layers.iter_mut().enumerate() { - tensors.push(( - Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner(), - Some(i), - )); - tensors.push(( - Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner(), - Some(i), - )); - tensors.push((Arc::get_mut(&mut layer.mlp.c_fc).unwrap().inner(), Some(i))); - tensors.push(( - Arc::get_mut(&mut layer.mlp.c_proj).unwrap().inner(), - Some(i), - )); + if let Some(x) = Arc::get_mut(&mut layer.self_attn.q_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.k_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.v_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.self_attn.o_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.c_fc).unwrap().inner() { + tensors.push((x, Some(i))); + } + if let Some(x) = Arc::get_mut(&mut layer.mlp.c_proj).unwrap().inner() { + tensors.push((x, Some(i))); + } } (tensors, &*self.mapper) } diff --git a/mistralrs-paged-attn/src/backend/mod.rs b/mistralrs-paged-attn/src/backend/mod.rs index d621007f0..ffedc0122 100644 --- a/mistralrs-paged-attn/src/backend/mod.rs +++ b/mistralrs-paged-attn/src/backend/mod.rs @@ -26,6 +26,7 @@ pub fn get_or_load_func( let spec = match dtype { DType::U8 => "_u8", DType::U32 => "_u32", + DType::I32 => "_i32", DType::I64 => "_i64", DType::BF16 => "_bf16", DType::F16 => "_f16", diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index 4ddf3bc29..18eaa98ce 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -20,7 +20,7 @@ pyo3.workspace = true mistralrs-core = { version = "0.2.4", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e09d7f3", features=["$feature_name"] } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "57c5599d", features=["$feature_name"] } indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 1ceb50bcb..d135f950d 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -84,7 +84,8 @@ fn parse_which( tokenizer_json, Some(model_id), ) - .build(arch.into()), + .build(arch.into()) + .map_err(|e| PyValueError::new_err(e.to_string()))?, Which::XLora { model_id, xlora_model_id, @@ -111,7 +112,8 @@ fn parse_which( no_kv_cache, 
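The paged-attention hunk above registers the new `I32` dtype with the kernel-name suffix scheme, so dtype-specialized CUDA functions can also be looked up for the GPTQ index tensors. A condensed sketch of that mapping; the fallback for unsupported dtypes is my assumption, and the real `get_or_load_func` also loads the PTX module:

```rust
use candle_core::DType;

/// Map a dtype to the suffix appended to a CUDA kernel name, mirroring the
/// match shown above. Unsupported dtypes are reported as an error here.
fn kernel_suffix(dtype: DType) -> Result<&'static str, String> {
    Ok(match dtype {
        DType::U8 => "_u8",
        DType::U32 => "_u32",
        DType::I32 => "_i32",
        DType::I64 => "_i64",
        DType::BF16 => "_bf16",
        DType::F16 => "_f16",
        other => return Err(format!("no kernel suffix for {other:?}")),
    })
}
```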
tgt_non_granular_index, ) - .build(arch.into()), + .build(arch.into()) + .map_err(|e| PyValueError::new_err(e.to_string()))?, Which::Lora { model_id, tokenizer_json, @@ -135,7 +137,8 @@ fn parse_which( ) .map_err(|e| PyValueError::new_err(e.to_string()))?, ) - .build(arch.into()), + .build(arch.into()) + .map_err(|e| PyValueError::new_err(e.to_string()))?, Which::GGUF { tok_model_id, quantized_model_id, diff --git a/mistralrs-quant/Cargo.toml b/mistralrs-quant/Cargo.toml new file mode 100644 index 000000000..b216a2f35 --- /dev/null +++ b/mistralrs-quant/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "mistralrs-quant" +readme = "README.md" +authors = ["Eric Buehler"] +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +homepage.workspace = true + +[dependencies] +candle-core.workspace = true +candle-nn.workspace = true +half.workspace = true +serde.workspace = true +lazy_static = "1.4" + +[features] +cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda"] + +[build-dependencies] +bindgen_cuda = { version = "0.1.5", optional = true } diff --git a/mistralrs-quant/README.md b/mistralrs-quant/README.md new file mode 100644 index 000000000..fcf651012 --- /dev/null +++ b/mistralrs-quant/README.md @@ -0,0 +1,10 @@ +# `mistralrs-quant` + +Quantization techniques for mistral.rs. This implements a common trait for all quantization methods to implement for ease of extension and development. + +Currently supported: +- GGUF: `GgufMatMul` +- Gptq: `GptqMatMul` + +Some kernels are copied or based on implementations in: +- https://github.com/vllm-project/vllm diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs new file mode 100644 index 000000000..a043b38d2 --- /dev/null +++ b/mistralrs-quant/build.rs @@ -0,0 +1,54 @@ +#[cfg(feature = "cuda")] +const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS"); + +fn main() { + #[cfg(feature = "cuda")] + { + use std::{path::PathBuf, vec}; + println!("cargo:rerun-if-changed=build.rs"); + let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); + let lib_files = vec!["kernels/gptq/q_gemm.cu"]; + for lib_file in lib_files.iter() { + println!("cargo:rerun-if-changed={lib_file}"); + } + let mut builder = bindgen_cuda::Builder::default() + .kernel_paths(lib_files) + .out_dir(build_dir.clone()) + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose"); + + // https://github.com/EricLBuehler/mistral.rs/issues/286 + if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { + builder = builder.arg("--compiler-options"); + builder = builder.arg(cuda_nvcc_flags_env); + } + + let out_file = build_dir.join("libmistralgptq.a"); + builder.build_lib(out_file); + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=mistralgptq"); + println!("cargo:rustc-link-lib=dylib=cudart"); + + let target = std::env::var("TARGET").unwrap(); + if target.contains("msvc") { + // nothing to link to + } else if target.contains("apple") + || target.contains("freebsd") + || target.contains("openbsd") + { + println!("cargo:rustc-link-lib=dylib=c++"); + } else if target.contains("android") { + 
println!("cargo:rustc-link-lib=dylib=c++_shared"); + } else { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + } +} diff --git a/mistralrs-quant/kernels/gptq/compat.cuh b/mistralrs-quant/kernels/gptq/compat.cuh new file mode 100644 index 000000000..1aa288725 --- /dev/null +++ b/mistralrs-quant/kernels/gptq/compat.cuh @@ -0,0 +1,60 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _compat_cuh +#define _compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} + + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif + + #endif +#endif + +#endif diff --git a/mistralrs-quant/kernels/gptq/matrix_view.cuh b/mistralrs-quant/kernels/gptq/matrix_view.cuh new file mode 100644 index 000000000..332f115d7 --- /dev/null +++ b/mistralrs-quant/kernels/gptq/matrix_view.cuh @@ -0,0 +1,290 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama +*/ + +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 
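`compat.cuh` above backfills `atomicAdd` for `half`/`half2` on pre-Volta GPUs by doing a compare-and-swap loop on the aligned 32-bit word that contains the half. The same idea, as a CPU analogue in Rust on `f32` bits (the CUDA version additionally has to select the correct half inside the 32-bit word, which is omitted here):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

/// CPU analogue of the CAS loop in `atomicAdd_half`: emulate a
/// floating-point atomic add on top of a 32-bit integer compare-exchange.
fn atomic_add_f32(cell: &AtomicU32, val: f32) {
    let mut old = cell.load(Ordering::Relaxed);
    loop {
        let new = (f32::from_bits(old) + val).to_bits();
        match cell.compare_exchange_weak(old, new, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => break,
            Err(actual) => old = actual, // another thread won the race; retry
        }
    }
}

fn main() {
    let cell = AtomicU32::new(1.5f32.to_bits());
    atomic_add_f32(&cell, 2.25);
    assert_eq!(f32::from_bits(cell.load(Ordering::Relaxed)), 3.75);
}
```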
= ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } +}; + +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, + const int height, 
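`MatrixView_q4_row` above reads 4-bit quantized values packed eight-per-`uint32` along a row. The same indexing, restated in Rust for clarity, operating on a flat row-major `&[u32]` with `width` divisible by 8 as the kernels assume:

```rust
/// Read the 4-bit value at (row, column) from a row-major matrix whose rows
/// pack eight 4-bit values per u32, mirroring `MatrixView_q4_row::item`.
fn q4_item(data: &[u32], width: usize, row: usize, column: usize) -> u8 {
    let shift = (column & 0x07) * 4;
    ((data[row * width / 8 + column / 8] >> shift) & 0x0f) as u8
}

/// Unpack four consecutive 4-bit values starting at `column`
/// (requires `column % 8 <= 4`, as in `item4`).
fn q4_item4(data: &[u32], width: usize, row: usize, column: usize) -> [u8; 4] {
    let shift = (column & 0x07) * 4;
    let d = data[row * width / 8 + column / 8] >> shift;
    [
        (d & 0x0f) as u8,
        ((d >> 4) & 0x0f) as u8,
        ((d >> 8) & 0x0f) as u8,
        ((d >> 12) & 0x0f) as u8,
    ]
}
```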
+ const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } +}; + +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; + } + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); + } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } +}; + +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } 
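The 2-bit and 8-bit row views follow the same recipe with different packing densities (16 and 4 values per `uint32` respectively); only the 3-bit view needs the special cases above because 32 is not a multiple of 3. As with the 4-bit view, the multi-value getters assume aligned columns (for example, the 8-bit `item4` computes its shift with `* 2`, which is only consistent for 4-aligned columns where the shift is zero anyway). For comparison, the 2-bit item lookup in Rust:

```rust
/// 2-bit lookup mirroring `MatrixView_q2_row::item`: sixteen values per u32.
fn q2_item(data: &[u32], width: usize, row: usize, column: usize) -> u8 {
    let shift = (column & 0x0f) * 2;
    ((data[row * width / 16 + column / 16] >> shift) & 0x03) as u8
}
```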
+};
+
+#endif
diff --git a/mistralrs-quant/kernels/gptq/q_gemm.cu b/mistralrs-quant/kernels/gptq/q_gemm.cu
new file mode 100644
index 000000000..38b1a637d
--- /dev/null
+++ b/mistralrs-quant/kernels/gptq/q_gemm.cu
@@ -0,0 +1,1761 @@
+/*
+Adapted from https://github.com/turboderp/exllamav2 and
+https://github.com/qwopqwop200/GPTQ-for-LLaMa
*/
+
+#include <cstdint>
+#include <cstdio>
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#include "compat.cuh"
+#include "matrix_view.cuh"
+#include "qdq_2.cuh"
+#include "qdq_3.cuh"
+#include "qdq_4.cuh"
+#include "qdq_8.cuh"
+
+#define BLOCK_KN_SIZE 128
+#define BLOCK_M_SIZE_MAX 8
+#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
+#define MAX_Q_GEMM_ROWS 50
+#define MAX_Q_GEMM_ROWS_8BIT 24
+#define MAX_ALT_GEMM_ROWS 8
+#define THREADS_X 32
+#define THREADS_Y 32
+#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
+
+#if defined(USE_ROCM)
+  #include <hipblas/hipblas.h>
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
+    hipblasHandle_t handle, hipblasOperation_t transA,
+    hipblasOperation_t transB, int m, int n, int k, const half* alpha,
+    const half* AP, int lda, const half* BP, int ldb, const half* beta,
+    half* CP, int ldc) {
+  return hipblasHgemm(handle, transA, transB, m, n, k,
+                      reinterpret_cast<const hipblasHalf*>(alpha),
+                      reinterpret_cast<const hipblasHalf*>(AP), lda,
+                      reinterpret_cast<const hipblasHalf*>(BP), ldb,
+                      reinterpret_cast<const hipblasHalf*>(beta),
+                      reinterpret_cast<hipblasHalf*>(CP), ldc);
+}
+  #define hipblasHgemm __compat_hipblasHgemm
+
+  // Previous versions of PyTorch were converting to rocBLAS instead of hipBLAS.
+  #define rocblas_operation_none HIPBLAS_OP_N
+  #define rocblas_hgemm __compat_hipblasHgemm
+#endif
+
+__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr,
+                                         const half2 g_result) {
+  half2 result = {};
+  const half2* a2_ptr = (const half2*)a_ptr;
+#pragma unroll
+  for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+  return __hadd2(result, g_result);
+}
+
+__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) {
+  half2 result = {};
+  const half2* a2_ptr = (const half2*)a_ptr;
+#pragma unroll
+  for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+  return __half2float(__low2half(result)) + __half2float(__high2half(result));
+}
+
+__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr,
+                                         const half2 g_result,
+                                         const half qs_h) {
+  half2 result = {};
+  const half2* a2_ptr = (const half2*)a_ptr;
+#pragma unroll
+  for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+  return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr,
+                                          const half2 g_result,
+                                          const half qs_h) {
+  half2 result = {};
+  const half2* a2_ptr = (const half2*)a_ptr;
+#pragma unroll
+  for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+  return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr,
+                                          const half2 g_result,
+                                          const half qs_h) {
+  half2 result = {};
+  const half2* a2_ptr = (const half2*)a_ptr;
+#pragma unroll
+  for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+  return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr,
+                                           const float g_result,
+                                           const float qs_f) {
+  half2 result = {};
+  const half2* a2_ptr = (const half2*)a_ptr;
+#pragma unroll
+  for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+  float result_f =
__half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, + const uint32_t*, const half*, + half*, const int, const int, + const int, const int, + const int*); + +template +__global__ void gemm_half_q_half_gptq_4bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + 
__shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], + block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], + block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], + block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], + block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), + __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), + __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_2bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row 
b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; + } + + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_3bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int 
offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_8bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = 
blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; + } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( + bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + 
SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; +} + +extern "C" void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* c, int size_m, int size_n, int size_k, + int m_count, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = + pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + kernel<<>>(a, b_q_weight, b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, + size_k, groups, b_q_perm); +} + +__global__ void reconstruct_exllama_8bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + 
__high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_4bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_3bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + 
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_2bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n 
>= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 2; p++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +extern "C" void reconstruct_exllama(const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* out, int height, int width, int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, + out); +} + +__global__ void gemm_half_q_half_alt_4bit_kernel( + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + int 
val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = + __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - + 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +__global__ void gemm_half_q_half_alt_8bit_kernel( + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = 
__halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), + __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), + __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +extern "C" void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, int size_m, int size_n, int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + kernel<<>>( + (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); +} + +template +__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, const int width, + const int group, + half* __restrict__ out) { + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32 / bit; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = + __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } +} + +__global__ void reconstruct_gptq_3bit_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + half* __restrict__ out) { + // Start of block + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 
+ 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); + } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } +} + +extern "C" void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, half* out, + int height, int width, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + kernel<<>>(b_q_weight, b_gptq_scales, + b_gptq_qzeros, b_g_idx, height, + width, groups, out); +} + +/* +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, half* temp_dq, int size_m, int size_n, + int size_k, int groups, bool use_exllama, int bit) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || + (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. 
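+    // For example, a 4-bit non-exllama weight only takes the fused alt kernel
+    // for small batches (size_m <= MAX_ALT_GEMM_ROWS, i.e. 8 rows); anything
+    // larger is reconstructed to fp16 and multiplied with cuBLAS below.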
+ use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); + } else { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); + } + + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit); + } + + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit); + } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit); + } +} +*/ + +__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } +} + +__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } +} + +__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } +} + +__global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * 
w2_stride + w2_column] = dst; +} + +__global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +__global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + int w_new_row = blockIdx.y * 3; + int q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | + ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | + ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); + } + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; +} + +__global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + 
w_new2[w_new2_row * w2_stride + w2_column] = dst; +} diff --git a/mistralrs-quant/kernels/gptq/qdq_2.cuh b/mistralrs-quant/kernels/gptq/qdq_2.cuh new file mode 100644 index 000000000..215c51308 --- /dev/null +++ b/mistralrs-quant/kernels/gptq/qdq_2.cuh @@ -0,0 +1,70 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +#endif diff --git a/mistralrs-quant/kernels/gptq/qdq_3.cuh b/mistralrs-quant/kernels/gptq/qdq_3.cuh new file mode 100644 index 000000000..c4f7f52aa --- /dev/null +++ b/mistralrs-quant/kernels/gptq/qdq_3.cuh @@ -0,0 +1,144 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: 
..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; + qa >>= 6; + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; + qb >>= 6; + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; + qc >>= 6; + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + 
dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#endif diff --git a/mistralrs-quant/kernels/gptq/qdq_4.cuh b/mistralrs-quant/kernels/gptq/qdq_4.cuh new file mode 100644 index 000000000..2053c2f1f --- /dev/null +++ b/mistralrs-quant/kernels/gptq/qdq_4.cuh @@ -0,0 +1,122 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + 
half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +#endif diff --git a/mistralrs-quant/kernels/gptq/qdq_8.cuh b/mistralrs-quant/kernels/gptq/qdq_8.cuh new file mode 100644 index 000000000..1678ea63c --- /dev/null +++ b/mistralrs-quant/kernels/gptq/qdq_8.cuh @@ -0,0 +1,24 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" + +__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} + +__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], int stride, + const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif diff --git a/mistralrs-quant/kernels/gptq/qdq_util.cuh b/mistralrs-quant/kernels/gptq/qdq_util.cuh new file mode 100644 index 000000000..6e5ee6136 --- /dev/null +++ b/mistralrs-quant/kernels/gptq/qdq_util.cuh @@ -0,0 +1,51 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} +}; + +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, + const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, + const int mask) { + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, + const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs new file mode 100644 index 000000000..24c5c5599 --- /dev/null +++ b/mistralrs-quant/src/gguf/mod.rs @@ -0,0 +1,105 @@ +use 
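+// GGUF-backed QuantMethod: wraps candle's QMatMul plus an optional bias.
+// add_delta_w (used when merging LoRA deltas) dequantizes QTensor weights,
+// applies the delta, then re-quantizes to the original GGUF dtype.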
std::sync::Arc; + +use candle_core::{quantized::QMatMul, DType, Result, Tensor}; +use candle_nn::Module; + +use crate::{QuantMethod, QuantMethodConfig}; + +pub struct GgufMatMul { + pub(crate) w: QMatMul, + pub(crate) b: Option, +} + +impl QuantMethod for GgufMatMul { + fn new(method: QuantMethodConfig) -> Result + where + Self: Sized, + { + match method { + QuantMethodConfig::Gguf { q_weight, b } => Ok(Self { + w: QMatMul::from_arc(q_weight)?, + b, + }), + QuantMethodConfig::Gptq { + bits: _, + use_exllama: _, + q_weight: _, + gptq_qzeros: _, + gptq_scales: _, + g_idx: _, + bias: _, + } + | QuantMethodConfig::Unquantized(_) => unreachable!(), + } + } + + fn forward(&self, a: &Tensor) -> Result { + let x = self.w.forward(a)?; + if let Some(ref b) = self.b { + x.broadcast_add(b) + } else { + Ok(x) + } + } + + fn forward_via_half(&self, a: &Tensor) -> Result { + let x = self.w.forward_via_f16(a)?; + if let Some(ref b) = self.b { + x.broadcast_add(b) + } else { + Ok(x) + } + } + + fn quantized_act_type(&self) -> Option { + Some(DType::F32) + } + + fn add_delta_w(&self, delta: &Tensor) -> Result> { + match self { + Self { + w: QMatMul::Tensor(w), + b, + } => Ok(Arc::new(Self { + w: QMatMul::Tensor((w + delta)?), + b: b.clone(), + })), + Self { + w: QMatMul::TensorF16(w), + b, + } => Ok(Arc::new(Self { + w: QMatMul::TensorF16((w + delta)?), + b: b.clone(), + })), + Self { + w: QMatMul::QTensor(w), + b, + } => { + let (w, dtype) = (w.dequantize(&w.device())?, w.dtype()); + let w = QMatMul::QTensor(std::sync::Arc::new( + candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?, + )); + Ok(Arc::new(Self { w, b: b.clone() })) + } + } + } + + fn dtype_and_device(&self) -> (DType, candle_core::Device) { + match &self.w { + QMatMul::QTensor(q) => (DType::F32, q.device()), + QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), + } + } + + fn get_qmatmul(&mut self) -> Option<&mut QMatMul> { + Some(&mut self.w) + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + self.b.as_mut() + } + + fn convert_to_isq(self: Arc) -> Result> { + Ok(self) + } +} diff --git a/mistralrs-quant/src/gptq/ffi.rs b/mistralrs-quant/src/gptq/ffi.rs new file mode 100644 index 000000000..f1d605f8a --- /dev/null +++ b/mistralrs-quant/src/gptq/ffi.rs @@ -0,0 +1,56 @@ +use half::f16; + +#[allow(dead_code)] +extern "C" { + pub(crate) fn reconstruct_exllama( + b_q_weight: *const u32, + b_gptq_qzeros: *const u32, + b_gptq_scales: *const f16, + b_q_perm: *const i32, + out: *mut f16, + size_k: i32, + size_n: i32, + groups: i32, + bit: i32, + ); + + pub(crate) fn reconstruct_gptq( + b_q_weight: *const u32, + b_gptq_qzeros: *const u32, + b_gptq_scales: *const f16, + b_q_perm: *const i32, + out: *mut f16, + size_k: i32, + size_n: i32, + groups: i32, + bit: i32, + ); + + pub(crate) fn gemm_half_q_half_cuda_part( + a: *const f16, + b_q_weight: *const u32, + b_gptq_qzeros: *const u32, + b_gptq_scales: *const f16, + b_q_perm: *const i32, + out: *mut f16, + m: i32, + n: i32, + k: i32, + m_count: i32, + groups: i32, + bit: i32, + ); + + pub(crate) fn gemm_half_q_half_alt( + a: *const f16, + b_q_weight: *const u32, + b_gptq_qzeros: *const u32, + b_gptq_scales: *const f16, + b_g_idx: *const i32, + out: *mut f16, + m: i32, + n: i32, + k: i32, + bit: i32, + ); +} diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs new file mode 100644 index 000000000..42c0c179c --- /dev/null +++ b/mistralrs-quant/src/gptq/gptq_cpu.rs @@ -0,0 +1,55 @@ +use crate::{QuantMethod, 
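+// Non-CUDA stub selected by gptq/mod.rs when the `cuda` feature is disabled:
+// `new()` bails immediately, so the `todo!()` trait methods below are never
+// reached at runtime.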
QuantMethodConfig}; +use candle_core::{quantized::QMatMul, DType, Result, Tensor}; +use std::sync::Arc; + +pub struct GptqMatMul; + +impl QuantMethod for GptqMatMul { + fn new(method: QuantMethodConfig) -> Result + where + Self: Sized, + { + match method { + QuantMethodConfig::Gptq { + bits: _, + use_exllama: _, + q_weight: _, + gptq_qzeros: _, + gptq_scales: _, + g_idx: _, + bias: _, + } => candle_core::bail!("GPTQ is only supported on CUDA."), + QuantMethodConfig::Gguf { q_weight: _, b: _ } | QuantMethodConfig::Unquantized(_) => { + unreachable!() + } + } + } + + fn forward(&self, _a: &Tensor) -> Result { + todo!() + } + + fn quantized_act_type(&self) -> Option { + todo!() + } + + fn add_delta_w(&self, _delta: &Tensor) -> Result> { + todo!() + } + + fn dtype_and_device(&self) -> (DType, candle_core::Device) { + todo!() + } + + fn get_qmatmul(&mut self) -> Option<&mut QMatMul> { + todo!() + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + todo!() + } + + fn convert_to_isq(self: Arc) -> Result> { + todo!() + } +} diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs new file mode 100644 index 000000000..2a29f4094 --- /dev/null +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -0,0 +1,305 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use candle_core::{ + cuda::{ + cudarc::{ + cublas::{result::hgemm, sys::cublasOperation_t}, + driver::{CudaSlice, DevicePtr}, + }, + CudaDType, CudaStorageSlice, WrapErr, + }, + from_storage_no_op, + quantized::QMatMul, + CudaDevice, CudaStorage, DType, Device, Result, Shape, Storage, Tensor, WithDType, D, +}; +use half::f16; +use lazy_static::lazy_static; + +use crate::{QuantMethod, QuantMethodConfig}; + +use super::ffi::{ + gemm_half_q_half_alt, gemm_half_q_half_cuda_part, reconstruct_exllama, reconstruct_gptq, +}; + +const MAX_Q_GEMM_ROWS_8BIT: i32 = 24; +const MAX_Q_GEMM_ROWS: i32 = 50; +const MAX_ALT_GEMM_ROWS: i32 = 8; +const BLOCK_M_SIZE_MAX: i32 = 8; + +lazy_static! 
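+// Process-wide cache of scratch buffers for the reconstructed f16 weights,
+// keyed by element count, so layers with identical packed-weight shapes share
+// a single temp_dq allocation.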
{ + static ref TMP_DQS: Mutex>> = Mutex::new(HashMap::new()); +} + +pub struct GptqMatMul { + q_weight: Tensor, // u32 + gptq_qzeros: Tensor, // u32 + gptq_scales: Tensor, // f16 + bias: Tensor, // f16 + g_idx: Tensor, // i32 + bits: i32, + use_exllama: bool, +} + +fn get_cuda_slice(x: &Tensor) -> *const T { + match &*x.storage_and_layout().0 { + Storage::Cuda(a_storage) => *a_storage + .as_cuda_slice::() + .expect("DType is not T") + .device_ptr() as *const T, + _ => panic!("Expected CUDA storage."), + } +} + +fn get_cuda_device(x: &Tensor) -> &CudaDevice { + match x.device() { + Device::Cuda(dev) => dev, + _ => panic!("Expected CUDA device"), + } +} + +impl GptqMatMul { + // https://github.com/vllm-project/vllm/blob/966fe72141e8365721840b7ababfb78601c23ead/csrc/quantization/gptq/q_gemm.cu#L1490 + // https://github.com/vllm-project/vllm/blob/966fe72141e8365721840b7ababfb78601c23ead/csrc/quantization/gptq/q_gemm.cu#L1823 + fn gptq_gemm(&self, a: Tensor, groups: i32, use_exllama: bool) -> Result { + if !a.is_contiguous() { + candle_core::bail!( + "Expected `a` to be contiguous, got strides {:?}", + a.layout().stride() + ) + } + let a_ptr = get_cuda_slice::(&a); + let b_q_weight = get_cuda_slice::(&self.q_weight) as *const u32; + let b_gptq_qzeros = get_cuda_slice::(&self.gptq_qzeros) as *const u32; + let b_gptq_scales = get_cuda_slice::(&self.gptq_scales); + let b_g_idx = get_cuda_slice::(&self.g_idx); + + let dev = get_cuda_device(&a); + + let c_shape = Shape::from_dims(&[a.dims()[0], self.q_weight.dims()[1]]); + + let (m, n, k) = ( + c_shape.dims()[0] as i32, + c_shape.dims()[1] as i32, + a.dims()[1] as i32, + ); + + let c = unsafe { dev.alloc::(c_shape.elem_count()).w()? }; + + let c_ptr = *c.device_ptr() as *mut f16; + + let len = (self.q_weight.dims()[0] * 32 / self.bits as usize) * self.q_weight.dims()[1]; + let temp_dq_ptr = *TMP_DQS.try_lock().unwrap().get(&len).unwrap().device_ptr() as *mut f16; + + let use_reconstruct = if use_exllama { + (self.bits == 8 && m > MAX_Q_GEMM_ROWS_8BIT) || (self.bits != 8 && m > MAX_Q_GEMM_ROWS) + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. + self.bits < 4 || m > MAX_ALT_GEMM_ROWS + }; + + if use_reconstruct { + // Reconstruct FP16 matrix, then cuBLAS + + let cublas_handle = match a.device() { + Device::Cuda(dev) => dev.cublas_handle(), + _ => unreachable!(), // invariant enforced earlier + }; + + let reconstruct_kernel = if use_exllama { + reconstruct_exllama + } else { + reconstruct_gptq + }; + unsafe { + reconstruct_kernel( + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + temp_dq_ptr, + k, + n, + groups, + self.bits, + ) + }; + + let alpha = f16::from_f32_const(1.0); + let beta = f16::from_f32_const(0.0); + + unsafe { + hgemm( + *cublas_handle.handle(), + cublasOperation_t::CUBLAS_OP_N, + cublasOperation_t::CUBLAS_OP_N, + n, + m, + k, + &alpha, + temp_dq_ptr as *const _, + n, + a_ptr as *const _, + k, + &beta, + c_ptr, + n, + ) + .w()? 
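+            // cuBLAS is column-major, so the call above passes (n, m, k) with the
+            // reconstructed weight as the first operand; the product lands in `c`
+            // already laid out as the row-major (m, n) output, no transpose needed.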
+ }; + } else if use_exllama { + let max_chunks = m / BLOCK_M_SIZE_MAX; + let last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + let last_chunk_size = m - last_chunk; + + if max_chunks > 0 { + unsafe { + gemm_half_q_half_cuda_part( + a_ptr as *const _, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + c_ptr, + last_chunk, + n, + k, + BLOCK_M_SIZE_MAX, + groups, + self.bits, + ) + } + } + if last_chunk_size > 0 { + unsafe { + gemm_half_q_half_cuda_part( + a_ptr.add((last_chunk * k) as usize), + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + c_ptr.add((last_chunk * n) as usize), + last_chunk_size, + n, + k, + last_chunk_size, + groups, + self.bits, + ) + } + } + } else { + unsafe { + gemm_half_q_half_alt( + a_ptr as *const _, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + c_ptr, + m, + n, + k, + self.bits, + ) + } + } + + let storage = CudaStorage { + slice: CudaStorageSlice::F16(c), + device: dev.clone(), + }; + let storage = Storage::Cuda(storage); + + Ok(from_storage_no_op(storage, c_shape, false)) + } +} + +impl QuantMethod for GptqMatMul { + fn new(method: QuantMethodConfig) -> Result + where + Self: Sized, + { + match method { + QuantMethodConfig::Gptq { + bits, + use_exllama, + q_weight, + gptq_qzeros, + gptq_scales, + g_idx, + bias, + } => { + let dev = get_cuda_device(&g_idx); + let len = (q_weight.dims()[0] * 32 / bits as usize) * q_weight.dims()[1]; + // SAFETY: used in the kernel as a tmp space, just preallocating it here. + if !TMP_DQS.lock().unwrap().contains_key(&len) { + TMP_DQS + .lock() + .unwrap() + .insert(len, unsafe { dev.alloc::(len).w()? }); + } + Ok(Self { + q_weight, + gptq_qzeros, + gptq_scales, + g_idx, + bits, + use_exllama, + bias, + }) + } + QuantMethodConfig::Gguf { q_weight: _, b: _ } | QuantMethodConfig::Unquantized(_) => { + unreachable!() + } + } + } + + fn forward(&self, a: &Tensor) -> Result { + // https://github.com/vllm-project/vllm/blob/ba991d5c84adbc0685075af88333c688ddb06011/vllm/model_executor/layers/quantization/gptq.py#L200 + let out_shape = Shape::from_dims( + &[ + &a.dims()[..a.dims().len() - 1], + &[self.q_weight.dim(D::Minus1)?], + ] + .concat(), + ); + let reshaped_a = a.reshape(((), a.dim(D::Minus1)?))?; + if !reshaped_a.device().is_cuda() { + candle_core::bail!("Expected CUDA input to GptqMatMul"); + } + let out = self.gptq_gemm( + reshaped_a, + self.gptq_qzeros.dim(0)? 
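+            // gptq_qzeros has shape (in_dim / group_size, out_dim / pack_factor),
+            // so dim(0) is the number of quantization groups along k.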
as i32, + self.use_exllama, + )?; + out.reshape(out_shape)?.broadcast_add(&self.bias) + } + + fn quantized_act_type(&self) -> Option { + Some(DType::F16) + } + + fn add_delta_w(&self, _delta: &Tensor) -> Result> { + candle_core::bail!("GPTQ quantization does not support adding weight delta.") + } + + fn dtype_and_device(&self) -> (DType, Device) { + (self.q_weight.dtype(), self.q_weight.device().clone()) + } + + fn get_qmatmul(&mut self) -> Option<&mut QMatMul> { + None + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + None + } + + fn convert_to_isq(self: Arc) -> Result> { + candle_core::bail!("GPTQ quantization does not support ISQ.") + } +} diff --git a/mistralrs-quant/src/gptq/mod.rs b/mistralrs-quant/src/gptq/mod.rs new file mode 100644 index 000000000..fcd1e0eea --- /dev/null +++ b/mistralrs-quant/src/gptq/mod.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "cuda")] +mod ffi; +#[cfg(not(feature = "cuda"))] +mod gptq_cpu; +#[cfg(feature = "cuda")] +mod gptq_cuda; + +#[cfg(not(feature = "cuda"))] +pub use gptq_cpu::GptqMatMul; +#[cfg(feature = "cuda")] +pub use gptq_cuda::GptqMatMul; diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs new file mode 100644 index 000000000..6c854b2ff --- /dev/null +++ b/mistralrs-quant/src/lib.rs @@ -0,0 +1,189 @@ +use std::{fmt::Display, sync::Arc}; + +use candle_core::{ + quantized::{QMatMul, QTensor}, + DType, Device, Result, Tensor, +}; + +mod gguf; +mod gptq; +mod unquantized; + +pub use gguf::GgufMatMul; +pub use gptq::GptqMatMul; +pub use unquantized::UnquantLinear; + +use candle_nn::{Linear, VarBuilder}; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize, Default)] +pub enum QuantMethodType { + #[default] + #[serde(rename = "gptq")] + Gptq, +} + +impl Display for QuantMethodType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Gptq => write!(f, "GPTQ"), + } + } +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct QuantizedConfig { + pub bits: usize, + pub quant_method: QuantMethodType, + pub group_size: usize, +} + +#[derive(Debug, Clone)] +pub enum QuantMethodConfig { + Gptq { + bits: i32, + use_exllama: bool, + q_weight: Tensor, + gptq_qzeros: Tensor, + gptq_scales: Tensor, + g_idx: Tensor, + bias: Tensor, + }, + Gguf { + q_weight: Arc, + b: Option, + }, + Unquantized(Linear), +} + +/// Quantized method for a quantized matmul. +pub trait QuantMethod: Send + Sync { + fn new(method: QuantMethodConfig) -> Result + where + Self: Sized; + + /// Compute matmul of `self` and `a`. `self` should contain the weights. + fn forward(&self, a: &Tensor) -> Result; + + /// Compute matmul of `self` and `a`. `self` should contain the weights. + /// This may go via half precision if it is supported. + fn forward_via_half(&self, a: &Tensor) -> Result { + self.forward(a) + } + + /// If a quantized method, return the activation dtype. + fn quantized_act_type(&self) -> Option; + + /// Weight dtype and device + fn dtype_and_device(&self) -> (DType, Device); + + /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha. + fn add_delta_w(&self, delta: &Tensor) -> Result>; + + /// If the quant is backed by a qmatmul. + fn get_qmatmul(&mut self) -> Option<&mut QMatMul>; + + /// If the quant is backed by a qmatmul. + fn get_bias_mut(&mut self) -> Option<&mut Tensor>; + + /// Convert this layer to an ISQ-able layer if possible. + fn convert_to_isq(self: Arc) -> Result>; +} + +macro_rules! 
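+// Number of quantized values packed into one 32-bit word (e.g. 32 / 4 = 8 for
+// 4-bit GPTQ). gptq_linear below uses this to derive the packed qweight shape
+// (in_dim / pack_factor, out_dim) and the qzeros shape
+// (in_dim / group_size, out_dim / pack_factor).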
pack_factor { + ($bits:expr) => { + 32 / $bits + }; +} + +pub fn linear_no_bias( + in_dim: usize, + out_dim: usize, + config: &Option, + vb: VarBuilder, +) -> Result> { + let layer = if let Some(quant_conf) = &config { + match quant_conf.quant_method { + QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?, + } + } else { + let layer = candle_nn::linear_no_bias(in_dim, out_dim, vb)?; + + let layer = ::new(QuantMethodConfig::Unquantized(layer))?; + Arc::new(layer) as Arc + }; + Ok(layer) +} + +pub fn linear( + in_dim: usize, + out_dim: usize, + config: &Option, + vb: VarBuilder, +) -> Result> { + let layer = if let Some(quant_conf) = &config { + match quant_conf.quant_method { + QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?, + } + } else { + let layer = candle_nn::linear(in_dim, out_dim, vb)?; + + let layer = ::new(QuantMethodConfig::Unquantized(layer))?; + Arc::new(layer) as Arc + }; + Ok(layer) +} + +pub fn linear_b( + in_dim: usize, + out_dim: usize, + bias: bool, + config: &Option, + vb: VarBuilder, +) -> Result> { + if bias { + linear(in_dim, out_dim, config, vb) + } else { + linear_no_bias(in_dim, out_dim, config, vb) + } +} + +pub fn gptq_linear( + in_dim: usize, + out_dim: usize, + config: &QuantizedConfig, + vb: VarBuilder, +) -> Result> { + let qweight = vb.get_with_hints_dtype( + (in_dim / pack_factor!(config.bits), out_dim), + "qweight", + Default::default(), + DType::I32, + )?; + let scale_and_zero_size = in_dim / config.group_size; + let qzeros = vb.get_with_hints_dtype( + (scale_and_zero_size, out_dim / pack_factor!(config.bits)), + "qzeros", + Default::default(), + DType::I32, + )?; + let g_idx = vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?; + let scales = vb.get_with_hints_dtype( + (scale_and_zero_size, out_dim), + "scales", + Default::default(), + DType::F16, + )?; + let bias = vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?; + + let config = QuantMethodConfig::Gptq { + bits: config.bits as i32, + use_exllama: false, + q_weight: qweight, + gptq_qzeros: qzeros, + gptq_scales: scales, + g_idx, + bias, + }; + Ok(Arc::new(GptqMatMul::new(config)?)) +} diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs new file mode 100644 index 000000000..a7db48985 --- /dev/null +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use candle_core::{quantized::QMatMul, DType, Result, Tensor}; +use candle_nn::{Linear, Module}; + +use crate::{GgufMatMul, QuantMethod, QuantMethodConfig}; + +pub struct UnquantLinear(Linear); + +impl QuantMethod for UnquantLinear { + fn new(method: QuantMethodConfig) -> candle_core::Result + where + Self: Sized, + { + match method { + QuantMethodConfig::Gguf { q_weight: _, b: _ } + | QuantMethodConfig::Gptq { + bits: _, + use_exllama: _, + q_weight: _, + gptq_qzeros: _, + gptq_scales: _, + g_idx: _, + bias: _, + } => unreachable!(), + QuantMethodConfig::Unquantized(l) => Ok(Self(l)), + } + } + + fn forward(&self, a: &Tensor) -> Result { + self.0.forward(a) + } + + fn quantized_act_type(&self) -> Option { + None + } + + fn add_delta_w(&self, delta: &Tensor) -> Result> { + Ok(Arc::new(Self(Linear::new( + (self.0.weight() + delta)?, + self.0.bias().cloned(), + )))) + } + + fn dtype_and_device(&self) -> (DType, candle_core::Device) { + (self.0.weight().dtype(), self.0.weight().device().clone()) + } + + fn get_qmatmul(&mut self) -> Option<&mut QMatMul> { + None + } + + fn get_bias_mut(&mut self) 
-> Option<&mut Tensor> { + None + } + + fn convert_to_isq(self: Arc) -> Result> { + let w = self.0.weight().clone(); + let b = self.0.bias().cloned(); + + Ok(Arc::new(GgufMatMul { + w: QMatMul::Tensor(w), + b, + })) + } +} diff --git a/mistralrs/examples/anymoe/main.rs b/mistralrs/examples/anymoe/main.rs index 33c6475e1..6d59c9fff 100644 --- a/mistralrs/examples/anymoe/main.rs +++ b/mistralrs/examples/anymoe/main.rs @@ -33,7 +33,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; let loader: Box = Box::new(AnyMoeLoader { target: loader, config: AnyMoeConfig { diff --git a/mistralrs/examples/anymoe_lora/main.rs b/mistralrs/examples/anymoe_lora/main.rs index 531a14de8..a606510c0 100644 --- a/mistralrs/examples/anymoe_lora/main.rs +++ b/mistralrs/examples/anymoe_lora/main.rs @@ -33,7 +33,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; let loader: Box = Box::new(AnyMoeLoader { target: loader, config: AnyMoeConfig { diff --git a/mistralrs/examples/gemma2/main.rs b/mistralrs/examples/gemma2/main.rs index 90c263e92..0837d9057 100644 --- a/mistralrs/examples/gemma2/main.rs +++ b/mistralrs/examples/gemma2/main.rs @@ -32,7 +32,7 @@ fn setup() -> anyhow::Result> { None, Some("google/gemma-2-9b-it".to_string()), ) - .build(NormalLoaderType::Gemma2); + .build(NormalLoaderType::Gemma2)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/grammar/main.rs b/mistralrs/examples/grammar/main.rs index 1f707c5fb..1eae9150d 100644 --- a/mistralrs/examples/grammar/main.rs +++ b/mistralrs/examples/grammar/main.rs @@ -32,7 +32,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/isq/main.rs b/mistralrs/examples/isq/main.rs index 917edafb0..a9832382e 100644 --- a/mistralrs/examples/isq/main.rs +++ b/mistralrs/examples/isq/main.rs @@ -33,7 +33,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/lora/main.rs b/mistralrs/examples/lora/main.rs index c0204d8a3..389000df0 100644 --- a/mistralrs/examples/lora/main.rs +++ b/mistralrs/examples/lora/main.rs @@ -39,7 +39,7 @@ fn setup() -> anyhow::Result> { panic!("Could not load ordering file at my-ordering-file.json") }))?, ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/lora_activation/main.rs b/mistralrs/examples/lora_activation/main.rs index 8a163d269..b2b20aa9f 100644 --- a/mistralrs/examples/lora_activation/main.rs +++ b/mistralrs/examples/lora_activation/main.rs @@ -39,7 +39,7 @@ fn setup() -> anyhow::Result> { panic!("Could not load ordering file at my-ordering-file.json") }))?, ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/paged_attn/main.rs 
b/mistralrs/examples/paged_attn/main.rs index 2896c90cb..71f408c86 100644 --- a/mistralrs/examples/paged_attn/main.rs +++ b/mistralrs/examples/paged_attn/main.rs @@ -39,7 +39,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/simple/main.rs b/mistralrs/examples/simple/main.rs index e78a4d998..6b45c68e5 100644 --- a/mistralrs/examples/simple/main.rs +++ b/mistralrs/examples/simple/main.rs @@ -32,7 +32,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/tools/main.rs b/mistralrs/examples/tools/main.rs index a1347785e..8f5687912 100644 --- a/mistralrs/examples/tools/main.rs +++ b/mistralrs/examples/tools/main.rs @@ -34,7 +34,7 @@ fn setup() -> anyhow::Result> { None, Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()), ) - .build(NormalLoaderType::Llama); + .build(NormalLoaderType::Llama)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None, diff --git a/mistralrs/examples/xlora/main.rs b/mistralrs/examples/xlora/main.rs index 6853b0173..85a902ca6 100644 --- a/mistralrs/examples/xlora/main.rs +++ b/mistralrs/examples/xlora/main.rs @@ -41,7 +41,7 @@ fn setup() -> anyhow::Result> { false, None, ) - .build(NormalLoaderType::Mistral); + .build(NormalLoaderType::Mistral)?; // Load, into a Pipeline let pipeline = loader.load_model_from_hf( None,
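
The example diffs above only adopt the now-fallible `.build(...)?`; the quantization machinery itself is reached through the mistralrs-quant entry points added earlier in this patch. Below is a minimal sketch, not part of the patch, of how a caller might wire a projection layer through linear_b. The helper names build_proj/run_proj and the 4-bit, group-size-128 values are illustrative assumptions; linear_b, QuantizedConfig, QuantMethodType and QuantMethod are the items defined in mistralrs-quant/src/lib.rs.

use std::sync::Arc;

use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use mistralrs_quant::{linear_b, QuantMethod, QuantMethodType, QuantizedConfig};

// Hypothetical helper: build one projection, GPTQ-quantized or plain.
fn build_proj(vb: VarBuilder, quantized: bool) -> Result<Arc<dyn QuantMethod>> {
    // With a config present, linear_b loads the packed GPTQ tensors
    // (qweight / qzeros / scales / g_idx, plus bias); without one it wraps a
    // plain candle Linear in UnquantLinear.
    let config = quantized.then(|| QuantizedConfig {
        bits: 4, // illustrative values, not taken from the diff
        quant_method: QuantMethodType::Gptq,
        group_size: 128,
    });
    linear_b(4096, 4096, /* bias */ true, &config, vb)
}

// Hypothetical helper: the trait object hides which backend actually runs.
fn run_proj(layer: &Arc<dyn QuantMethod>, xs: &Tensor) -> Result<Tensor> {
    layer.forward(xs)
}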