diff --git a/.gitignore b/.gitignore index 9a2aada80..8b9bf1609 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ /target .ruff_cache .vscode -*.a \ No newline at end of file +*.a +.DS_Store +architecture.md +.gitignore diff --git a/Cargo.lock b/Cargo.lock index f2ec794e4..dc83283c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" dependencies = [ "backtrace", ] @@ -168,9 +168,9 @@ checksum = "5b8a30a44e99a1c83ccb2a6298c563c888952a1c9134953db26876528f84c93a" [[package]] name = "async-trait" -version = "0.1.82" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", @@ -191,9 +191,9 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "axum" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec" dependencies = [ "async-trait", "axum-core", @@ -225,9 +225,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00" dependencies = [ "async-trait", "bytes", @@ -238,7 +238,7 @@ dependencies = [ "mime", "pin-project-lite", 
"rustversion", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.1", "tower-layer", "tower-service", "tracing", @@ -386,14 +386,14 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "candle-core" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "accelerate-src", "byteorder", @@ -420,7 +420,7 @@ dependencies = [ [[package]] name = "candle-flash-attn" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "anyhow", "bindgen_cuda 0.1.5", @@ -431,7 +431,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -439,7 +439,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "metal", "once_cell", @@ -450,7 +450,7 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.6.1" -source = 
"git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "accelerate-src", "candle-core", @@ -467,9 +467,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.18" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ "shlex", ] @@ -538,9 +538,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.17" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3" dependencies = [ "clap_builder", "clap_derive", @@ -548,9 +548,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.17" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b" dependencies = [ "anstream", "anstyle", @@ -560,9 +560,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.13" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck", "proc-macro2", @@ -576,29 +576,6 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" -[[package]] -name = "cli-table" -version = "0.4.9" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b53f9241f288a7b12c56565f04aaeaeeab6b8923d42d99255d4ca428b4d97f89" -dependencies = [ - "cli-table-derive", - "csv", - "termcolor", - "unicode-width", -] - -[[package]] -name = "cli-table-derive" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e83a93253aaae7c74eb7428ce4faa6e219ba94886908048888701819f82fb94" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "color_quant" version = "1.1.0" @@ -1090,9 +1067,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" dependencies = [ "simd-adler32", ] @@ -1632,9 +1609,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" +checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" dependencies = [ "bytes", "futures-channel", @@ -1645,16 +1622,15 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", - "tower", "tower-service", "tracing", ] [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1861,9 +1837,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.158" +version = 
"0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "libloading" @@ -2025,9 +2001,9 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.2.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad" +checksum = "1028b628753a7e1a88fc59c9ba4b02ecc3bc0bd3c7af23df667bc28df9b3310e" dependencies = [ "serde", "serde_json", @@ -2035,9 +2011,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.2.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b" +checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a" dependencies = [ "minijinja", "serde", @@ -2110,21 +2086,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "mistralrs-bench" -version = "0.3.0" -dependencies = [ - "anyhow", - "candle-core", - "clap", - "cli-table", - "mistralrs-core", - "serde", - "serde_json", - "tokio", - "tracing", -] - [[package]] name = "mistralrs-core" version = "0.3.0" @@ -2706,26 +2667,6 @@ dependencies = [ "siphasher", ] -[[package]] -name = "pin-project" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.77", -] - [[package]] name = "pin-project-lite" version = "0.2.14" @@ -2740,9 +2681,9 @@ checksum = 
"8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "png" @@ -2759,9 +2700,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" [[package]] name = "ppv-lite86" @@ -2850,9 +2791,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" dependencies = [ "anyhow", "cfg-if", @@ -2881,9 +2822,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" dependencies = [ "once_cell", "target-lexicon", @@ -2891,9 +2832,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" dependencies = [ "libc", "pyo3-build-config", @@ -2901,9 +2842,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.22.3" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2913,9 +2854,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" dependencies = [ "heck", "proc-macro2", @@ -3103,9 +3044,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" -version = "0.5.4" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853" +checksum = "355ae415ccd3a04315d3f8246e86d67689ea74d88d915576e1589a351062a13b" dependencies = [ "bitflags 2.6.0", ] @@ -3434,9 +3375,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.1" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -3518,9 +3459,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" dependencies = [ "serde", ] @@ -3832,9 +3773,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.41" +version = "0.4.42" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" +checksum = "4ff6c40d3aedb5e06b57c6f669ad17ab063dd1e63d977c6a88e7f4dfa4f04020" dependencies = [ "filetime", "libc", @@ -3860,29 +3801,20 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", @@ -4053,9 +3985,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.20" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap", "serde", @@ -4066,14 +3998,14 @@ dependencies = [ [[package]] name = "tower" -version = "0.4.13" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" dependencies = [ "futures-core", "futures-util", - "pin-project", 
"pin-project-lite", + "sync_wrapper 0.1.2", "tokio", "tower-layer", "tower-service", @@ -4242,9 +4174,9 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] @@ -4266,9 +4198,9 @@ checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode_categories" @@ -4538,9 +4470,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.5" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" +checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] @@ -4781,9 +4713,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.18" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 6c47d33e3..b1b70326e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,12 +4,13 @@ members = [ "mistralrs-core", "mistralrs-pyo3", "mistralrs", - "mistralrs-bench", + #"mistralrs-bench", 
"mistralrs-vision", "mistralrs-quant", ] exclude = [ "mistralrs-paged_attn", + "mistralrs-bench", ] resolver = "2" @@ -25,8 +26,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 321c1d1fa..1a9fdddc3 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" diff --git a/mistralrs-core/build.rs b/mistralrs-core/build.rs index 1fae6e92a..5dc2e6ae8 100644 --- a/mistralrs-core/build.rs +++ b/mistralrs-core/build.rs @@ -28,6 +28,7 @@ fn main() { // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); + //builder = builder.arg("-fPIC -fPIE"); builder = builder.arg(cuda_nvcc_flags_env); } diff --git a/mistralrs-core/src/cuda/ffi.rs b/mistralrs-core/src/cuda/ffi.rs index f945b0ad1..fce6031d6 100644 --- a/mistralrs-core/src/cuda/ffi.rs +++ b/mistralrs-core/src/cuda/ffi.rs @@ -11,6 +11,7 @@ extern "C" { pub(crate) fn count_nonzero_i16(d_in: *const 
c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32) -> u32; + pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32) -> u32; pub(crate) fn nonzero_bf16( d_in: *const c_void, N: u32, @@ -83,6 +84,14 @@ extern "C" { num_dims: u32, d_out: *mut c_void, ); + pub(crate) fn nonzero_i16( + d_in: *const c_void, + N: u32, + num_nonzero: u32, + dims: *const c_void, + num_dims: u32, + d_out: *mut c_void, + ); pub(crate) fn bitwise_and_u8( d_in1: *const c_void, @@ -108,6 +117,12 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_and_i16( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); pub(crate) fn bitwise_or_u8( d_in1: *const c_void, d_in2: *const c_void, @@ -132,6 +147,12 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_or_i16( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); pub(crate) fn bitwise_xor_u8( d_in1: *const c_void, d_in2: *const c_void, @@ -156,9 +177,16 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_xor_i16( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); // Linked to in mistralrs-quant pub(crate) fn leftshift_u8(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_u32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_i64(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_i32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); + pub(crate) fn leftshift_i16(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); } diff --git a/mistralrs-core/src/exl2/mod.rs b/mistralrs-core/src/exl2/mod.rs new file mode 100644 index 000000000..f1b165ed5 --- /dev/null +++ b/mistralrs-core/src/exl2/mod.rs @@ -0,0 +1,87 @@ +use anyhow::{Context, Result}; +use std::{num::NonZeroUsize, str::FromStr}; 
+use strum::EnumString; + +use crate::{pipeline::QuantizationKind, Loader, ModelDType, ModelKind, Topology}; + +pub const EXL2_MULTI_FILE_DELIMITER: &str = " "; + +#[derive(Debug, EnumString, Clone, Copy)] +#[strum(serialize_all = "kebab-case")] +pub enum EXL2Architecture { + Llama, + Mpt, + Gptneox, + Gptj, + Gpt2, + Bloom, + Falcon, + Mamba, + Rwkv, + Phi2, + Phi3, + Starcoder2, +} + +// Wraps from_str() for some convenience: +// - Case-insensitive variant matching (TODO: is this desirable?) +// - Customized error until potential upstream support: https://github.com/Peternator7/strum/issues/332 +impl EXL2Architecture { + pub fn from_value + std::fmt::Display>(value: T) -> Result { + Self::from_str(&value.as_ref().to_ascii_lowercase()) + .with_context(|| format!("Unknown EXL2 architecture `{value}`")) + .map_err(anyhow::Error::msg) + } +} + +pub struct EXL2LoaderBuilder { + model_id: Option, + quantized_model_id: String, + quantized_filenames: Vec, + kind: ModelKind, + config: EXL2SpecificConfig, +} + +pub struct EXL2SpecificConfig { + pub topology: Option, + pub gpu_split: Option, + pub length: Option, + pub rope_scale: Option, + pub rope_alpha: Option, + pub no_flash_attn: bool, + pub no_xformers: bool, + pub no_sdpa: bool, + pub low_mem: bool, + pub experts_per_token: Option, + pub load_q4: bool, + pub fast_safetensors: bool, + pub ignore_compatibility: bool, + pub chunk_size: Option, +} + +impl EXL2LoaderBuilder { + pub fn new( + chat_template: Option, + tok_model_id: Option, + quantized_model_id: String, + quantized_filenames: Vec, + config: EXL2SpecificConfig, + ) -> Self { + let kind = ModelKind::Quantized { + quant: QuantizationKind::Exl2, + }; + + Self { + model_id: tok_model_id, + quantized_model_id, + quantized_filenames, + kind, + config, + } + } + + pub fn build(self) -> Result> { + // Implement the loading logic for EXL2 models here + todo!("Implement EXL2 model loading") + } +} diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 
2f3e151f3..0d1ab78bc 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -39,6 +39,7 @@ mod amoe; mod cublaslt; #[cfg(not(all(feature = "cuda", target_family = "unix")))] mod dummy_paged_attention; +mod exl2; mod gguf; pub mod layers; mod layers_masker; diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index b94369023..dfc5d431c 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -3,7 +3,9 @@ use std::{ num::NonZeroUsize, }; +use crate::exl2::{EXL2LoaderBuilder, EXL2SpecificConfig}; use crate::{ + exl2::EXL2_MULTI_FILE_DELIMITER, get_toml_selected_model_dtype, pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig}, DiffusionLoaderBuilder, DiffusionSpecificConfig, GGUFSpecificConfig, Loader, ModelDType, @@ -57,6 +59,7 @@ pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option { match model { ModelSelected::Plain { .. } | ModelSelected::Lora { .. } + | ModelSelected::EXL2 { .. } | ModelSelected::GGUF { .. } | ModelSelected::LoraGGUF { .. } | ModelSelected::GGML { .. } @@ -87,6 +90,7 @@ pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result { | ModelSelected::VisionPlain { dtype, .. } | ModelSelected::DiffusionPlain { dtype, .. } => Ok(*dtype), ModelSelected::GGUF { .. } + | ModelSelected::EXL2 { .. } | ModelSelected::LoraGGUF { .. } | ModelSelected::GGML { .. } | ModelSelected::LoraGGML { .. 
} @@ -228,6 +232,52 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result EXL2LoaderBuilder::new( + args.chat_template, + tok_model_id, + quantized_model_id, + quantized_filename + .split(EXL2_MULTI_FILE_DELIMITER) + .map(ToOwned::to_owned) + .collect::>(), + EXL2SpecificConfig { + topology: Topology::from_option_path(topology)?, + gpu_split, + length, + rope_scale, + rope_alpha, + no_flash_attn, + no_xformers, + no_sdpa, + low_mem, + experts_per_token, + load_q4, + fast_safetensors, + ignore_compatibility, + chunk_size, + }, + ) + .build()?, ModelSelected::XLoraGGUF { tok_model_id, quantized_model_id, diff --git a/mistralrs-core/src/model_selected.rs b/mistralrs-core/src/model_selected.rs index 4573ee61d..94d7932dc 100644 --- a/mistralrs-core/src/model_selected.rs +++ b/mistralrs-core/src/model_selected.rs @@ -391,4 +391,80 @@ pub enum ModelSelected { #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)] dtype: ModelDType, }, + + /// Select an EXL2 model + EXL2 { + /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. + #[arg(short, long)] + tok_model_id: Option, + + /// Quantized model ID to find the `quantized_filename`. + /// This may be a HF hub repo or a local path. + #[arg(short = 'm', long)] + quantized_model_id: String, + + /// Quantized filename(s). + /// May be a single filename, or use a delimiter of " " (a single space) for multiple files. + #[arg(short = 'f', long)] + quantized_filename: String, + + /// Path to a topology YAML file. 
+ #[arg(long)] + topology: Option, + + // Specific EXL2 args + /// "auto", or VRAM allocation per GPU in GB + #[arg(short, long)] + gpu_split: Option, + + /// Maximum sequence length + #[arg(short, long)] + length: Option, + + /// RoPE scaling factor + #[arg(short, long)] + rope_scale: Option, + + /// RoPE alpha value (NTK) + #[arg(short, long)] + rope_alpha: Option, + + /// Disable Flash Attention + #[arg(long, action)] + no_flash_attn: bool, + + /// Disable xformers, an alternative plan of flash attn for older devices + #[arg(long, action)] + no_xformers: bool, + + /// Disable Torch SDPA + #[arg(long, action)] + no_sdpa: bool, + + /// Enable VRAM optimizations, potentially trading off speed + #[arg(short, long, action)] + low_mem: bool, + + /// Override MoE model's default number of experts per token + #[arg(short, long)] + experts_per_token: Option, + + /// Load weights in Q4 mode + #[arg(long, action)] + load_q4: bool, + + /// Use alternative safetensors loader (with direct I/O when available) + #[arg(long, action)] + fast_safetensors: bool, + + /// Do not override model config options in case of compatibility issues + #[arg(short, long, action)] + ignore_compatibility: bool, + + /// Chunk size ('input length') + #[arg(long)] + chunk_size: Option, + }, } diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index 0d7b5321d..739974ab1 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -119,6 +119,12 @@ impl CustomOp2 for BitWise { let result = CpuStorage::I32(result); Ok((result, l1.shape().clone())) } + CpuStorage::I16(vs1) => { + let vs2 = s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::I16(result); + Ok((result, l1.shape().clone())) + } CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")), diff --git a/mistralrs-core/src/pipeline/exl2.rs b/mistralrs-core/src/pipeline/exl2.rs new file mode 100644 index 000000000..2cef958e7 --- /dev/null +++ b/mistralrs-core/src/pipeline/exl2.rs @@ -0,0 +1,86 @@ +use crate::{ChatTemplate, Loader}; +use anyhow::Result; +use std::sync::Arc; +use tokenizers::Tokenizer; + +use crate::{ + models::quantized_llama::ModelWeights as QLlama, + models::quantized_phi2::ModelWeights as QPhi, + models::quantized_phi3::ModelWeights as QPhi3, + models::quantized_starcoder2::ModelWeights as QStarcoder2, + xlora_models::{XLoraQLlama, XLoraQPhi3}, +}; + +use super::GeneralMetadata; + +enum Model { + Llama(QLlama), + Phi2(QPhi), + XLoraLlama(XLoraQLlama), + XLoraPhi3(XLoraQPhi3), + Phi3(QPhi3), + Starcoder2(QStarcoder2), +} + +pub struct EXL2Pipeline { + model: Model, + tokenizer: Arc, + chat_template: Arc, + model_id: String, + metadata: Arc, +} + +pub struct EXL2Loader { + tok_model_id: Option, + quantized_model_id: String, + 
quantized_filename: String, + config: EXL2SpecificConfig, +} + +pub struct EXL2SpecificConfig { + pub gpu_split: Option, + pub length: Option, + pub rope_scale: Option, + pub rope_alpha: Option, + pub no_flash_attn: bool, + pub no_xformers: bool, + pub no_sdpa: bool, + pub low_mem: bool, + pub experts_per_token: Option, + pub load_q4: bool, + pub fast_safetensors: bool, + pub ignore_compatibility: bool, + pub chunk_size: Option, +} + +pub struct EXL2LoaderBuilder { + tok_model_id: Option, + quantized_model_id: String, + quantized_filename: String, + topology: Option, + config: EXL2SpecificConfig, +} + +impl EXL2LoaderBuilder { + pub fn new( + tok_model_id: Option, + quantized_model_id: String, + quantized_filename: String, + topology: Option, + config: EXL2SpecificConfig, + ) -> Self { + Self { + tok_model_id, + quantized_model_id, + quantized_filename, + topology, + config: EXL2SpecificConfig { ..config }, + } + } + + pub fn build(self) -> Result> { + // Implementation details for building the EXL2 loader would go here + // This is a placeholder and would need to be filled in with the actual implementation + todo!("Implement EXL2 loader building") + } +} diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index 53f873d4f..15fc3e90b 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -270,6 +270,8 @@ pub enum QuantizationKind { Gguf, /// GPTQ Gptq, + /// EXL2 + Exl2, } #[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index dc117bbc4..aefd17495 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -2,6 +2,7 @@ mod amoe; mod cache_manager; pub mod chat_template; mod diffusion; +mod exl2; mod ggml; mod gguf; mod inputs_processor; @@ -16,6 +17,7 @@ mod speculative; mod vision; pub use 
super::diffusion_models::DiffusionGenerationParams; + use crate::aici::toktree::TokTrie; use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult}; use crate::diffusion_models::response::send_responses; diff --git a/mistralrs-paged-attn/build.rs b/mistralrs-paged-attn/build.rs index 1f640bcdd..02f839e58 100644 --- a/mistralrs-paged-attn/build.rs +++ b/mistralrs-paged-attn/build.rs @@ -36,6 +36,7 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks}; // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); + //builder = builder.arg("-fPIC -fPIE"); builder = builder.arg(cuda_nvcc_flags_env); } println!("cargo:info={builder:?}"); diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index 8890cbfca..3059d62d0 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -20,7 +20,7 @@ pyo3.workspace = true mistralrs-core = { version = "0.3.0", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368", features=["$feature_name"] } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8", features=["$feature_name"] } indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index d9e09f1c6..9af6b4832 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -8,6 +8,8 @@ fn main() { println!("cargo:rerun-if-changed=build.rs"); let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let lib_files = vec![ + "kernels/exl2/q_gemm_exl2.cu", + "kernels/exl2/q_matrix.cu", 
"kernels/gptq/q_gemm.cu", "kernels/hqq/hqq.cu", "kernels/ops/ops.cu", @@ -32,6 +34,7 @@ fn main() { // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); + //builder = builder.arg("-fPIC -fPIE"); builder = builder.arg(cuda_nvcc_flags_env); } diff --git a/mistralrs-quant/kernels/exl2/compat.cuh b/mistralrs-quant/kernels/exl2/compat.cuh new file mode 100644 index 000000000..9e7851c5c --- /dev/null +++ b/mistralrs-quant/kernels/exl2/compat.cuh @@ -0,0 +1,59 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _compat_cuh +#define _compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/matrix_view.cuh b/mistralrs-quant/kernels/exl2/matrix_view.cuh new file mode 100644 index 000000000..dd0aebf52 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/matrix_view.cuh @@ -0,0 +1,124 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "quant/qdq_util.cuh" + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row 
* width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = 
value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu new file mode 100644 index 000000000..a867ac017 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu @@ -0,0 +1,103 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ + +#include + +#include "q_matrix.cuh" +#include "matrix_view.cuh" +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" +#include "q_gemm_kernel.cuh" + +#define MAX_Q_GEMM_ROWS 32 +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE 
/ 32) +#if defined(USE_ROCM) +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + #define hipblasHgemm __compat_hipblasHgemm +#endif +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +extern "C" void gemm_half_q_half_cuda_part_exl2( + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + bool clear +) { + dim3 blockDim, gridDim; + blockDim.x = EXL2_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(b->height, EXL2_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count); + + kernel<<>>( + a, b->cuda_q_weight, b->cuda_q_scale, b->cuda_q_scale_max, c, size_m, + size_n, size_k, b->height, b->groups, b->cuda_q_group_map, + b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3, + b->rows_2, clear); +} + +extern "C" uintptr_t exl2_make_q_matrix( + const int device, + const int height, + const int width, + const int groups, + uint32_t* q_weight, + uint16_t* q_perm, + uint16_t* q_invperm, + uint32_t* q_scale, + half* q_scale_max, + uint16_t* q_groups, + uint16_t* q_group_map +) { + QMatrix* m = new QMatrix + ( + device, + height, + width, + groups, + q_weight, + q_perm, + q_invperm, + q_scale, + q_scale_max, + q_groups, + q_group_map + ); + return reinterpret_cast(m); +} + +extern "C" void exl2_reconstruct_q_matrix(uintptr_t q_matrix, half* out) { + QMatrix* m = reinterpret_cast(q_matrix); + m->reconstruct(out); +} + +extern "C" void 
exl2_destroy_q_matrix(uintptr_t q_matrix) { + QMatrix* m = reinterpret_cast(q_matrix); + delete m; +} \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh b/mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh new file mode 100644 index 000000000..6612dabd1 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh @@ -0,0 +1,556 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#include "compat.cuh" + +#define MAX_Q_GEMM_WEIGHTS 4 +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + 
__half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half 
dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + + +typedef void (*fp_gemm_half_q_half_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const uint16_t*, + const int, + const int, + const int, + const int, + const int, + const int, + const bool +); + +template +__global__ void gemm_half_q_half_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int height, + const int groups, + const uint16_t* __restrict__ b_q_group_map, + const uint16_t* __restrict__ b_q_perm, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2, + const bool clear +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; + + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, height); + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + + // Preload block_a + + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { 
+ const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + half a0 = a_ptr[b_q_perm[offset_k + t]]; +// half a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Clear + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + //int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + +// if (offset_m == 0 && t == 0) +// DBGI2(offset_k, group); + + // Preload scales + + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; + + //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) + { + int qscales[4]; + b_q_scale_.item4(qscales, group + g, n); + qscales[0]++; + qscales[1]++; + qscales[2]++; + qscales[3]++; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += b_q_group_map[temp_k * 2 + 1]; + } + + // a, b offset + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? 
min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = EXL2_BLOCK_KN_SIZE; + + // Initial group + + int scales_idx = 0; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + // Column result + + half block_c[m_count][4] = {}; + + // Dequantize groups + + int k = offset_k; + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = 
scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 2; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 16; + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[5]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[3] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[4] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n); + dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n); + 
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n); + dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_4bit_8(load_int4[0].x, dq[0], size_n); + dequant_4bit_8(load_int4[0].y, dq[1], size_n); + dequant_4bit_8(load_int4[0].z, dq[2], size_n); + dequant_4bit_8(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[3]; + 
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_2bit_16(load_int4[0].x, dq[0], size_n); + dequant_2bit_16(load_int4[0].y, dq[1], size_n); + dequant_2bit_16(load_int4[0].z, dq[2], size_n); + dequant_2bit_16(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + + a_ptr += 16; + } + k += 16; + } + + // Accumulate column sums in c + + for (int m = 0; m < m_count; 
m++) + { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + atomicAdd(out , result01); + atomicAdd(out + 1, result23); +// *out = result01; +// *(out + 1) = result23; + } +} + +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) + { + #if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel<8>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) +{ + return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); +} \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_matrix.cu b/mistralrs-quant/kernels/exl2/q_matrix.cu new file mode 100644 index 000000000..cab969a8e --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_matrix.cu @@ -0,0 +1,435 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ + +#include + +#include "q_matrix.cuh" +#include "matrix_view.cuh" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define 
THREADS_Y 32 + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +// Shuffle quantized data on load + +__global__ void shuffle_kernel( + uint32_t *__restrict__ b_q_weight, + const int size_k, + const int size_n, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) + return; + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < rows_8) + { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } + while (k < rows_6) + { + shuffle_6bit_16(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 16; + } + while (k < rows_5) + { + shuffle_5bit_32(b_ptr, size_n); + b_ptr += 5 * size_n; + k += 32; + } + while (k < rows_4) + { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } + while (k < rows_3) + { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } + while (k < rows_2) + { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +// QMatrix constructor + +QMatrix::QMatrix( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t *_q_weight, + uint16_t *_q_perm, + uint16_t *_q_invperm, + uint32_t *_q_scale, + half *_q_scale_max, + uint16_t *_q_groups, + uint16_t *_q_group_map) : device(_device), + height(_height), + width(_width), + groups(_groups) +{ + cudaSetDevice(device); + + failed = false; + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_q_scale = _q_scale; + cuda_q_scale_max = _q_scale_max; + cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; + + // Create group map + + rows_8 = 0; + rows_6 = 0; + rows_5 = 0; + rows_4 = 0; + rows_3 = 0; + rows_2 = 0; + + { + uint16_t *cpu_q_groups = (uint16_t *)calloc(groups * 2, sizeof(uint16_t)); + cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); + + int row = 0; + for (int i = 
0; i < groups; i++) + { + int bits = cpu_q_groups[i * 2]; + + int rows; + if (i < groups - 1) + { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } + else + rows = height - row; + + if (bits == 8) + rows_8 += rows; + if (bits == 6) + rows_6 += rows; + if (bits == 5) + rows_5 += rows; + if (bits == 4) + rows_4 += rows; + if (bits == 3) + rows_3 += rows; + if (bits == 2) + rows_2 += rows; + row += rows; + } + + free(cpu_q_groups); + + rows_6 += rows_8; + rows_5 += rows_6; + rows_4 += rows_5; + rows_3 += rows_4; + rows_2 += rows_3; + } + + // Shuffle quantized data + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + + shuffle_kernel<<>>(cuda_q_weight, height, width,rows_8, rows_6, rows_5,rows_4, rows_3, rows_2); +} + +QMatrix::~QMatrix() {} + +// Reconstruct b[k,n] + +__global__ void reconstruct_kernel( + const uint32_t *__restrict__ b_q_weight, + const uint16_t *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_q_scale, + const half *__restrict__ b_q_scale_max, + const uint16_t *__restrict__ b_q_group_map, + const int size_k, + const int size_n, + // const int groupsize, + const int groups, + half *__restrict__ b, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x; + + // Preload remapping table + + int t = threadIdx.x; + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + + // Column + + int n = offset_n + t; + if (n >= size_n) + return; + + // Find initial group + + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > 
rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + half2 qs_h2 = __halves2half2(qs_h, qs_h); + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int k = offset_k; + int lk = 0; + + __syncthreads(); + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + dequant_8bit_8(q_0, q_1, dq, size_n); + for (int j = 0; j < 4; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 8; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 2; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + dequant_6bit_16(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 8; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half 
*)dq; + for (int j = 0; j < 16; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + uint32_t q_3 = *b_ptr; + b_ptr += size_n; + uint32_t q_4 = *b_ptr; + b_ptr += size_n; + dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); + for (int j = 0; j < 16; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 32; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + dequant_4bit_8(q_0, dq, size_n); + for (int j = 0; j < 4; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 8; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + dequant_3bit_32(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 16; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 32; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_2 && k 
< end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + dequant_2bit_16(q_0, dq, size_n); + for (int j = 0; j < 8; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 16; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 16; + } +} + +void QMatrix::reconstruct(half *out) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_kernel<<>>( + cuda_q_weight, + cuda_q_perm, + cuda_q_scale, + cuda_q_scale_max, + cuda_q_group_map, + height, + width, + // groupsize, + groups, + out, + rows_8, + rows_6, + rows_5, + rows_4, + rows_3, + rows_2); + } +} \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_matrix.cuh b/mistralrs-quant/kernels/exl2/q_matrix.cuh new file mode 100644 index 000000000..6eba6284e --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_matrix.cuh @@ -0,0 +1,72 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +#define MAX_SUPERGROUPS 16 + +class QMatrix +{ +public: + + int device; + bool is_gptq; + + int height; + int width; + int groups; + int gptq_groupsize; + + int rows_8; + int rows_6; + int rows_5; + int rows_4; + int rows_3; + int rows_2; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_q_scale = NULL; + half* cuda_q_scale_max = NULL; + uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + bool failed; + + QMatrix + ( + const int _device, + const int _height, 
+ const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + uint16_t* _q_group_map + ); + + ~QMatrix(); + + void reconstruct(half* out); + bool make_sequential(const uint32_t* cpu_g_idx); + +private: + +}; + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_2.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_2.cuh new file mode 100644 index 000000000..d4fdc337a --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_2.cuh @@ -0,0 +1,78 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 2.0f); + const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z4 = __halves2half2(z4_, z4_); + const half2 z16 = __halves2half2(z16_, z16_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t 
qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_3.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_3.cuh new file mode 100644 index 000000000..b357e020a --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_3.cuh @@ -0,0 +1,138 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff 
feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 4.0f); + const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) 
+ 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y8, z8); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hfma2( q3.as_half2, y8, z8); + dq[ 4] = __hfma2( q4.as_half2, y64, z64); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hfma2( q6.as_half2, y8, z8); + dq[ 7] = __hadd2( q7.as_half2, z1); + dq[ 8] = __hfma2( q8.as_half2, y8, z8); + dq[ 9] = __hfma2( q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_4.cuh 
b/mistralrs-quant/kernels/exl2/quant/qdq_4.cuh new file mode 100644 index 000000000..cf1d52d60 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_4.cuh @@ -0,0 +1,141 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = 
__hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] 
- z ) + } +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_5.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_5.cuh new file mode 100644 index 000000000..9866fc9b9 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_5.cuh @@ -0,0 +1,170 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_5_cuh +#define _qdq_5_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// v5555533 33311111 u4444422 22200000 (u, v lsb) +// vbbbbb99 99977777 uaaaaa88 88866666 +// vhhhhhff fffddddd ugggggee eeeccccc +// vnnnnnll llljjjjj ummmmmkk kkkiiiii +// vtttttrr rrrppppp usssssqq qqqooooo + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + uint32_t qd = q[3 * stride]; + uint32_t qe = q[4 * stride]; + + // qa: 66555554 44443333 32222211 11100000 + // qb: ccccbbbb baaaaa99 99988888 77777666 + // qc: jiiiiihh hhhggggg fffffeee eedddddc + // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj + // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp + + uint32_t qf = qe >> 22; + qe <<= 8; + qe |= qd >> 24; + qd <<= 6; + qd |= qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: 555554 44443333 32222211 11100000 + // qb: bbbbba aaaa9999 98888877 77766666 + // qc: hhhhhg ggggffff feeeeedd dddccccc + // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii + // qe: ttttts ssssrrrr rqqqqqpp pppooooo + // qf: vv vvvuuuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + uint32_t zd = 0; + uint32_t ze = 0; + + for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 
5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); } + + // za: 5555533 33311111 4444422 22200000 + // zb: bbbbb99 99977777 aaaaa88 88866666 + // zc: hhhhhff fffddddd gggggee eeeccccc + // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii + // ze: tttttrr rrrppppp sssssqq qqqooooo + // qf: vv vvvuuuuu + + za |= ((qf & 0x001) >> 0) << 15; + zb |= ((qf & 0x002) >> 1) << 15; + zc |= ((qf & 0x004) >> 2) << 15; + zd |= ((qf & 0x008) >> 3) << 15; + ze |= ((qf & 0x010) >> 4) << 15; + za |= ((qf & 0x020) >> 5) << 31; + zb |= ((qf & 0x040) >> 6) << 31; + zc |= ((qf & 0x080) >> 7) << 31; + zd |= ((qf & 0x100) >> 8) << 31; + ze |= ((qf & 0x200) >> 9) << 31; + + // za: v5555533 33311111 u4444422 22200000 (u, v lsb) + // zb: vbbbbb99 99977777 uaaaaa88 88866666 + // zc: vhhhhhff fffddddd ugggggee eeeccccc + // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii + // ze: vtttttrr rrrppppp usssssqq qqqooooo + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; + q[3 * stride] = zd; + q[4 * stride] = ze; +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y32_ = __float2half_rn(1.0f / 32.0f); + const half2 y32 = __halves2half2(y32_, y32_); + const half z1_ = __float2half_rn(-1024.0f - 16.0f); + const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z32 = __halves2half2(z32_, z32_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + uint32_t qd = q_3; + uint32_t qe = q_4; + + half2_uint32 q0 ((qa & 
0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 + qa >>= 10; + half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 + qa >>= 5; + qa &= 0x00010001; + half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024 + half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 + qb >>= 10; + half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 + qb >>= 4; + qb &= 0x00020002; + half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 + half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 + qc >>= 10; + half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 + qc >>= 3; + qc &= 0x00040004; + half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 + half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 + qd >>= 10; + half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 + qd >>= 2; + qd &= 0x00080008; + half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 + qe >>= 10; + half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 + qe >>= 1; + qe &= 0x00100010; + half2_uint32 q15((qa | qb | qc | qd | qe) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y32, z32); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hadd2( q3.as_half2, z1); + dq[ 4] = __hfma2( q4.as_half2, y32, z32); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hadd2( q6.as_half2, z1); + dq[ 7] = __hfma2( q7.as_half2, y32, z32); + dq[ 8] = __hadd2( q8.as_half2, z1); + dq[ 9] = __hadd2( q9.as_half2, z1); + dq[10] = __hfma2(q10.as_half2, y32, z32); + dq[11] = __hadd2(q11.as_half2, z1); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y32, z32); + dq[14] = __hadd2(q14.as_half2, z1); + dq[15] = 
__hadd2(q15.as_half2, z1); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_6.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_6.cuh new file mode 100644 index 000000000..43b2659a1 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_6.cuh @@ -0,0 +1,36 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_6_cuh +#define _qdq_6_cuh + +#include "qdq_util.cuh" + +__forceinline__ __device__ void shuffle_6bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_6bit_16 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); + dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); + for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); + dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); + for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_8.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_8.cuh new file mode 100644 index 000000000..807f7fb96 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_8.cuh @@ -0,0 +1,32 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" + +__forceinline__ __device__ void shuffle_8bit_4 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_8bit_8 +( + const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 
+} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_util.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_util.cuh new file mode 100644 index 000000000..79a4bf365 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_util.cuh @@ -0,0 +1,56 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs new file mode 100644 index 000000000..6611a27f8 --- /dev/null +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -0,0 +1,300 @@ +use std::{ + num::NonZeroUsize, + 
sync::{atomic::AtomicUsize, Arc, Mutex}, +}; + +use candle_core::{ + cuda::{ + cudarc::{ + cublas::{result::hgemm, sys::cublasOperation_t}, + driver::DevicePtr, + }, + WrapErr, + }, + DType, Device, Result, Shape, Tensor, D, +}; +use half::f16; + +use crate::{ + utils::{get_cuda_device, get_cuda_slice}, + IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, +}; + +use super::ffi::{exl2_destroy_q_matrix, exl2_make_q_matrix, exl2_reconstruct_q_matrix}; + +const MAX_Q_GEMM_ROWS: i32 = 32; + +#[derive(Debug)] +pub struct Exl2Layer { + q_weight: Tensor, + q_scale: Tensor, + q_groups: Tensor, + q_invperm: Tensor, + bias: Option, + exllama_state: Arc>, +} + +#[derive(Debug)] +struct ExllamaState { + initialized: bool, + q_scale_max: Tensor, + q_perm: Tensor, + q_invperm_short: Tensor, + q_group_map: Tensor, + q_matrix: *mut std::ffi::c_void, +} + +unsafe impl Send for ExllamaState {} +unsafe impl Sync for ExllamaState {} + +impl Exl2Layer { + fn new( + q_weight: Tensor, + q_scale: Tensor, + q_scale_max: Tensor, + q_groups: Tensor, + q_invperm: Tensor, + bias: Option, + ) -> Result { + let exllama_state = Arc::new(Mutex::new(ExllamaState { + initialized: false, + q_scale_max, + q_perm: Tensor::zeros((1,), DType::I16, q_invperm.device())?, + q_group_map: Tensor::zeros((1,), DType::I16, q_invperm.device())?, + q_invperm_short: Tensor::zeros(q_invperm.shape(), DType::I16, q_invperm.device())?, + q_matrix: std::ptr::null_mut(), + })); + + let this = Self { + q_weight, + q_scale, + q_groups, + q_invperm, + bias, + exllama_state, + }; + this.initialize_exllama()?; + Ok(this) + } + + fn initialize_exllama(&self) -> Result<()> { + let mut state = self.exllama_state.lock().unwrap(); + if state.initialized { + return Ok(()); + } + + let dev = get_cuda_device(&self.q_weight)?; + + state.q_scale_max = (state.q_scale_max.clone() / 256.0)?; + state.q_invperm_short = self.q_invperm.to_dtype(DType::I16)?; + state.q_perm = state + .q_invperm_short + .arg_sort_last_dim(false)? 
+ .to_dtype(DType::I16)?; + state.q_group_map = make_group_map(&self.q_groups, self.q_weight.dim(0)?)?; + + let dev_ord = dev.ordinal() as i32; + let b_width = self.q_weight.dims()[1] as i32; + let b_height = state.q_perm.dims()[0] as i32; + let b_groups = self.q_scale.dims()[0] as i32; + let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; + let b_q_perm = get_cuda_slice::(&state.q_perm)? as *const u16; + let b_q_invperm = get_cuda_slice::(&self.q_invperm)? as *const u16; + let b_q_scale = get_cuda_slice::(&self.q_scale)? as *const u32; + let b_q_scale_max = get_cuda_slice::(&state.q_scale_max)?; + let b_q_groups = get_cuda_slice::(&self.q_groups)? as *const u16; + let b_q_group_map = get_cuda_slice::(&state.q_group_map)? as *const u16; + + state.q_matrix = unsafe { + exl2_make_q_matrix( + dev_ord, + b_height, + b_width, + b_groups, + b_q_weight, + b_q_perm, + b_q_invperm, + b_q_scale, + b_q_scale_max, + b_q_groups, + b_q_group_map, + ) + }; + + state.initialized = true; + Ok(()) + } + + fn exl2_gemm(&self, a: Tensor) -> Result { + self.initialize_exllama()?; + + let dev = get_cuda_device(&a)?; + let a_ptr = get_cuda_slice::(&a)?; + + let qm_width = self.q_weight.dim(1)?; + let c_shape = Shape::from_dims(&[a.dims()[0], qm_width]); + + let (m, n, k) = ( + c_shape.dims()[0] as i32, + c_shape.dims()[1] as i32, + a.dims()[1] as i32, + ); + + let c = unsafe { dev.alloc::(c_shape.elem_count()).w()? }; + let c_ptr = *c.device_ptr() as *mut f16; + + if m > MAX_Q_GEMM_ROWS { + let temp_dq = if m > MAX_Q_GEMM_ROWS { + Tensor::zeros(&[k as usize, n as usize], DType::F16, a.device())? + } else { + Tensor::zeros(&[0, 0], DType::F16, a.device())? + }; + let temp_dq_ptr = get_cuda_slice::(&temp_dq)? 
as *mut f16; + + // Reconstruct FP16 matrix, then cuBLAS + unsafe { + exl2_reconstruct_q_matrix(self.exllama_state.lock().unwrap().q_matrix, temp_dq_ptr); + } + + let alpha = f16::from_f32(1.0); + let beta = f16::from_f32(0.0); + let cublas_handle = match a.device() { + Device::Cuda(dev) => dev.cublas_handle(), + _ => unreachable!(), // invariant enforced earlier + }; + + unsafe { + hgemm( + *cublas_handle.handle(), + cublasOperation_t::CUBLAS_OP_N, + cublasOperation_t::CUBLAS_OP_N, + n, + m, + k, + &alpha, + temp_dq_ptr as *const _, + n, + a_ptr as *const _, + k, + &beta, + c_ptr, + n, + ) + .w()? + }; + } else { + todo!() + } + todo!() + } +} + +fn make_group_map(q_groups: &Tensor, num_qrows: usize) -> Result { + let gr = q_groups.to_vec1::()?; + let mut group_map = Vec::new(); + let num_groups = gr.len() / 2; + + for i in 0..num_groups { + let bits = gr[i * 2] as usize; + let qrows = if i < num_groups - 1 { + gr[i * 2 + 3] as usize - gr[i * 2 + 1] as usize + } else { + num_qrows - gr[i * 2 + 1] as usize + }; + let rows = qrows * 32 / bits; + for j in 0..rows { + group_map.push(i as i16); + group_map.push((rows - j) as i16); + } + } + + Tensor::from_vec(group_map.clone(), (group_map.len(),), q_groups.device()) +} + +impl QuantMethod for Exl2Layer { + fn new(method: QuantMethodConfig) -> Result { + match method { + QuantMethodConfig::Exl2 { + q_weight, + q_scale, + q_scale_max, + q_groups, + q_invperm, + bias, + } => Self::new(q_weight, q_scale, q_scale_max, q_groups, q_invperm, bias), + QuantMethodConfig::Gptq { .. } + | QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Unquantized(_) + | QuantMethodConfig::Hqq { .. 
} => { + unreachable!() + } + } + } + + fn forward(&self, x: &Tensor) -> Result { + let out_shape = Shape::from_dims( + &[ + &x.dims()[..x.dims().len() - 1], + &[self.q_weight.dim(D::Minus1)?], + ] + .concat(), + ); + let reshaped_x = x.reshape(((), x.dim(D::Minus1)?))?; + let mut output = self.exl2_gemm(reshaped_x)?; + if let Some(bias) = &self.bias { + output = output.broadcast_add(bias)?; + } + output.reshape(out_shape) + } + + fn quantized_act_type(&self) -> Option { + Some(DType::F16) + } + + fn add_delta_w(&self, _delta: &Tensor) -> Result> { + candle_core::bail!("EXL2 quantization does not support adding weight delta.") + } + + fn dtype_and_device(&self) -> (DType, Device) { + todo!() + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + None + } + + fn apply_isq( + self: Arc, + _dtype: Option, + _device: Device, + _n_quantized: &AtomicUsize, + ) -> Result> { + candle_core::bail!("EXL2 quantization does not support ISQ.") + } + + fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option { + None + } +} + +impl Drop for Exl2Layer { + fn drop(&mut self) { + if let Ok(mut state) = self.exllama_state.lock() { + if !state.q_matrix.is_null() { + unsafe { + exl2_destroy_q_matrix(state.q_matrix); + } + state.q_matrix = std::ptr::null_mut(); + } + } + } +} + +impl QuantizedSerde for Exl2Layer { + fn isq_serde_supported(&self) -> bool { + false + } + fn name(&self) -> &'static str { + "exl2" + } +} diff --git a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs new file mode 100644 index 000000000..ec431d930 --- /dev/null +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -0,0 +1,28 @@ +use half::f16; +use std::ffi::c_void; + +// Opaque pointer type for QMatrix +type QMatrixPtr = *mut c_void; + +#[allow(dead_code)] +extern "C" { + pub fn exl2_make_q_matrix( + device: i32, + height: i32, // q_perm.size(0); + width: i32, // q_weight.size(1); + groups: i32, // q_scale.size(0); + q_weight: *const u32, + q_perm: *const u16, + q_invperm: *const u16, + 
q_scale: *const u32, + q_scale_max: *const f16, + q_groups: *const u16, + q_group_map: *const u16, + ) -> QMatrixPtr; + + pub fn exl2_destroy_q_matrix(q_matrix: QMatrixPtr); + + pub fn exl2_reconstruct_q_matrix(q_matrix: QMatrixPtr, out: *mut f16); + + pub fn exl2_gemm_cuda(a: *const f16, b: *const c_void, c: *mut f16, m: i32, n: i32, k: i32); +} diff --git a/mistralrs-quant/src/exl2/mod.rs b/mistralrs-quant/src/exl2/mod.rs new file mode 100644 index 000000000..309bcebf2 --- /dev/null +++ b/mistralrs-quant/src/exl2/mod.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "cuda")] +mod exl2_cuda; +#[cfg(feature = "cuda")] +mod ffi; + +#[cfg(feature = "cuda")] +pub use exl2_cuda::Exl2Layer; diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs index 4a7f5ec67..e8ae2f969 100644 --- a/mistralrs-quant/src/gguf/mod.rs +++ b/mistralrs-quant/src/gguf/mod.rs @@ -34,7 +34,8 @@ impl QuantMethod for GgufMatMul { w: QMatMul::from_arc(q_weight)?, b, }), - QuantMethodConfig::Gptq { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } => unreachable!(), } diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs index 9d844faa2..0f42d8af4 100644 --- a/mistralrs-quant/src/gptq/gptq_cpu.rs +++ b/mistralrs-quant/src/gptq/gptq_cpu.rs @@ -23,7 +23,8 @@ impl QuantMethod for GptqLayer { g_idx: _, bias: _, } => candle_core::bail!("GPTQ is only supported on CUDA."), - QuantMethodConfig::Gguf { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } => { unreachable!() diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs index f3740aa50..b8ccf92a1 100644 --- a/mistralrs-quant/src/gptq/gptq_cuda.rs +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -237,7 +237,8 @@ impl QuantMethod for GptqLayer { bias, }) } - QuantMethodConfig::Gguf { .. 
} + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } => { unreachable!() diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs index 8c2236195..90ee8016f 100644 --- a/mistralrs-quant/src/hqq/mod.rs +++ b/mistralrs-quant/src/hqq/mod.rs @@ -525,7 +525,8 @@ impl QuantMethod for HqqLayer { Self: Sized, { match method { - QuantMethodConfig::Gguf { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Gptq { .. } => { unreachable!() diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 5cbd6e933..5544975dc 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -7,15 +7,17 @@ use std::{ use candle_core::{ quantized::{GgmlDType, QTensor}, - DType, Device, Result, Tensor, + Context, DType, Device, Result, Tensor, }; +mod exl2; mod gguf; mod gptq; mod hqq; mod unquantized; mod utils; +use exl2::Exl2Layer; pub use gguf::GgufMatMul; pub use gptq::GptqLayer; pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; @@ -24,30 +26,40 @@ pub use unquantized::UnquantLinear; use candle_nn::{Linear, VarBuilder}; use serde::Deserialize; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize)] pub enum QuantMethodType { - #[default] #[serde(rename = "gptq")] Gptq, + #[serde(rename = "exl2")] + Exl2, } impl Display for QuantMethodType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Gptq => write!(f, "GPTQ"), + Self::Exl2 => write!(f, "EXL2"), } } } -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize)] pub struct QuantizedConfig { pub bits: usize, pub quant_method: QuantMethodType, - pub group_size: usize, + pub group_size: Option, } #[derive(Debug, Clone)] pub enum QuantMethodConfig { + Exl2 { + q_weight: Tensor, + q_scale: Tensor, + q_scale_max: Tensor, + q_groups: 
Tensor, + q_invperm: Tensor, + bias: Option, + }, Gptq { bits: i32, use_exllama: bool, @@ -223,6 +235,7 @@ pub fn linear_no_bias( let layer = if let Some(quant_conf) = &config { match quant_conf.quant_method { QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?, + QuantMethodType::Exl2 => todo!(), } } else { let layer = candle_nn::linear_no_bias(in_dim, out_dim, vb)?; @@ -242,6 +255,7 @@ pub fn linear( let layer = if let Some(quant_conf) = &config { match quant_conf.quant_method { QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?, + QuantMethodType::Exl2 => todo!(), } } else { let layer = candle_nn::linear(in_dim, out_dim, vb)?; @@ -278,7 +292,10 @@ pub fn gptq_linear( Default::default(), DType::I32, )?; - let scale_and_zero_size = in_dim / config.group_size; + let scale_and_zero_size = in_dim + / config + .group_size + .context("GPTQ requires group size in QuantizedConfig")?; let qzeros = vb.get_with_hints_dtype( (scale_and_zero_size, out_dim / pack_factor!(config.bits)), "qzeros", @@ -305,3 +322,24 @@ pub fn gptq_linear( }; Ok(Arc::new(GptqLayer::new(config)?)) } + +pub fn exl2_linear( + _in_dim: usize, + _out_dim: usize, + _config: &QuantizedConfig, + vb: VarBuilder, +) -> Result> { + let q_weight = vb.get_unchecked_dtype("q_weight", DType::I32)?; + let q_scale_max = vb.get_unchecked_dtype("q_scale_max", DType::F16)?; + let q_scale = vb.get_unchecked_dtype("q_scale", DType::I32)?; + let q_invperm = vb.get_unchecked_dtype("q_invperm", DType::I32)?; + let q_groups = vb.get_unchecked_dtype("q_groups", DType::I16)?; + Ok(Arc::new(Exl2Layer::new(QuantMethodConfig::Exl2 { + q_weight, + q_scale, + q_scale_max, + q_groups, + q_invperm, + bias: None, + })?)) +} diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs index e44b7ea92..d8d587205 100644 --- a/mistralrs-quant/src/unquantized/mod.rs +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -25,7 +25,8 @@ impl QuantMethod for UnquantLinear { 
Self: Sized, { match method { - QuantMethodConfig::Gguf { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Hqq { .. } => unreachable!(), QuantMethodConfig::Unquantized(l) => Ok(Self(l)), diff --git a/mistralrs-quant/src/utils/uqff.rs b/mistralrs-quant/src/utils/uqff.rs index 1494357aa..e9ca1e573 100644 --- a/mistralrs-quant/src/utils/uqff.rs +++ b/mistralrs-quant/src/utils/uqff.rs @@ -116,9 +116,9 @@ pub(crate) fn deserialize_tensor( DType::BF16 => bytes_to_data::(&tensor_data, &dims, device), DType::F32 => bytes_to_data::(&tensor_data, &dims, device), DType::F64 => bytes_to_data::(&tensor_data, &dims, device), + DType::I16 => bytes_to_data::(&tensor_data, &dims, device), DType::I32 => bytes_to_data::(&tensor_data, &dims, device), DType::I64 => bytes_to_data::(&tensor_data, &dims, device), - DType::I16 => bytes_to_data::(&tensor_data, &dims, device), DType::U32 => bytes_to_data::(&tensor_data, &dims, device), DType::U8 => bytes_to_data::(&tensor_data, &dims, device), }