Make the CUDA backend an optional feature (in preparation for Metal devices) (#96)
guoqingbao authored Dec 30, 2024
1 parent 6358cb5 commit 8da81eb
Showing 5 changed files with 334 additions and 144 deletions.
139 changes: 136 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 4 additions & 2 deletions Cargo.toml
@@ -43,12 +43,14 @@
 tracing = "0.1.40"
 range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" }
 either = { version = "1.13.0", features = ["serde"] }
 dirs = "5.0.1"
-kernels = {path = "./kernels", version="0.1.0"}
+kernels = {path = "./kernels", version="0.1.0", optional = true}
+#metal-kernels = {path = "./metal-kernels", version="0.1.0", optional = true}
 
 [features]
 default = ["cuda"]
 accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:kernels"]
+metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
 cudnn = ["candle-core/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
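With `kernels` now optional and pulled in through `dep:kernels`, CUDA becomes opt-in (though still the default), and a `metal` feature is stubbed out alongside it. A minimal sketch of how downstream code can branch on these features at compile time, assuming the `cuda`/`metal` features declared above; the `default_device` helper is hypothetical, not part of this crate:

// Hypothetical sketch: compile-time backend selection driven by the
// [features] table above. `cfg!` evaluates against whichever features
// the build enabled (e.g. `--no-default-features --features metal`).
fn default_device() -> &'static str {
    if cfg!(feature = "cuda") {
        "cuda"
    } else if cfg!(feature = "metal") {
        "metal"
    } else {
        "cpu"
    }
}

fn main() {
    println!("selected backend: {}", default_device());
}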
2 changes: 1 addition & 1 deletion src/backend/cache.rs
@@ -142,7 +142,7 @@ pub unsafe fn copy_blocks(
        COPY_BLOCKS_KERNEL_NAME,
        key_caches.first().unwrap().dtype(),
        None,
-       dev,
+       &dev,
    ));
 
    try_api!(unsafe {
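The only change here is passing the device by reference rather than by value, presumably to match the kernel loader's updated signature. A minimal stand-alone sketch of the pattern; `Device` and `get_kernel` are hypothetical stand-ins, not candle APIs:

// Hypothetical sketch: borrowing `dev` lets the caller keep using it
// after the kernel lookup; passing by value would move it on first use.
struct Device {
    ordinal: usize,
}

fn get_kernel(name: &str, dev: &Device) -> String {
    format!("{name} on device {}", dev.ordinal)
}

fn main() {
    let dev = Device { ordinal: 0 };
    let copy = get_kernel("copy_blocks_kernel", &dev); // borrow, not move
    let swap = get_kernel("swap_blocks_kernel", &dev); // `dev` is still usable
    println!("{copy}\n{swap}");
}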
13 changes: 9 additions & 4 deletions src/backend/gptq.rs
@@ -1,9 +1,8 @@
 use candle::backend::BackendStorage;
-use candle::cuda_backend::cudarc::driver::DevicePtr;
-use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, CudaStorage, DType, Layout, Result, Shape, Storage, Tensor};
 use candle_core as candle;
 use half::{bf16, f16};
+#[cfg(feature = "cuda")]
 use kernels::ffi::{gemm_half_q_half_alt, gptq_repack, marlin_4bit_bf16, marlin_4bit_f16};
 
 struct GPTQMatMul {
@@ -14,6 +13,7 @@ struct GPTQMatMul {
 }
 
 impl GPTQMatMul {
+    #[cfg(feature = "cuda")]
     fn cuda_fwd_t<
         T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
     >(
@@ -25,6 +25,8 @@ impl GPTQMatMul {
         scale: &CudaStorage,
         scale_l: &Layout,
     ) -> Result<(CudaStorage, Shape)> {
+        use candle::cuda_backend::cudarc::driver::DevicePtr;
+        use candle::cuda_backend::WrapErr;
         let dev = qweight.device();
         let x_shape = x_l.dims();
         let weight_shape = qweight_l.dims();
@@ -169,7 +171,7 @@ impl candle::CustomOp3 for GPTQMatMul {
     ) -> Result<(CpuStorage, Shape)> {
         candle::bail!("no cpu support for GPTQMatMul")
     }
-
+    #[cfg(feature = "cuda")]
     fn cuda_fwd(
         &self,
         x: &CudaStorage,
@@ -210,13 +212,16 @@ struct GPTQRepack {
 }
 
 impl GPTQRepack {
+    #[cfg(feature = "cuda")]
     fn cuda_fwd_t<
         T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
     >(
         &self,
         qweight: &CudaStorage,
         qweight_l: &Layout,
     ) -> Result<(CudaStorage, Shape)> {
+        use candle::cuda_backend::cudarc::driver::DevicePtr;
+        use candle::cuda_backend::WrapErr;
         let dev = qweight.device();
         let q_shape = qweight_l.dims();
         let mut out_shape: Vec<usize> = q_shape.to_vec();
@@ -252,7 +257,7 @@ impl candle::CustomOp1 for GPTQRepack {
     fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> {
         candle::bail!("no cpu support for GPTQRepack")
     }
-
+    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, qweight: &CudaStorage, qweight_l: &Layout) -> Result<(CudaStorage, Shape)> {
        match qweight.dtype() {
            DType::U32 => self.cuda_fwd_t::<u32>(qweight, qweight_l),
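The pattern throughout this file: CUDA-only imports move inside the `#[cfg(feature = "cuda")]`-gated functions, the `cpu_fwd` methods keep bailing unconditionally, and the GPU entry points disappear entirely from non-CUDA builds. A reduced, self-contained sketch of the same shape, with hypothetical names and assuming a `cuda` feature declared in Cargo.toml:

// Hypothetical sketch of the gating pattern above, stripped of candle types.
struct Repack;

impl Repack {
    // The CPU path always compiles and simply reports no support,
    // mirroring `candle::bail!` in the real code.
    fn cpu_fwd(&self) -> Result<Vec<u32>, String> {
        Err("no cpu support for Repack".into())
    }

    // Compiled only when the `cuda` feature is on, like `cuda_fwd` above.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, qweight: &[u32]) -> Result<Vec<u32>, String> {
        // Function-local `use` keeps feature-only imports out of other builds.
        use std::convert::identity; // stand-in for the CUDA-only imports
        Ok(identity(qweight.to_vec()))
    }
}

fn main() {
    let op = Repack;
    println!("cpu: {:?}", op.cpu_fwd());
    #[cfg(feature = "cuda")]
    println!("cuda: {:?}", op.cuda_fwd(&[1, 2, 3]));
}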