diff --git a/.github/workflows/cargo-check-features.yml b/.github/workflows/cargo-check-features.yml
index fc86b324d..34f92d136 100644
--- a/.github/workflows/cargo-check-features.yml
+++ b/.github/workflows/cargo-check-features.yml
@@ -9,9 +9,9 @@ jobs:
       matrix:
         config:
           - toolchain: stable
-            command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cpu-mkl-matmul,cuda,test-cuda
+            command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cpu-mkl-matmul,cuda,cudnn
           - toolchain: nightly
-            command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cpu-mkl-matmul,cuda,test-cuda
+            command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cpu-mkl-matmul,cuda,cudnn
     steps:
       - uses: actions/checkout@v2
diff --git a/.github/workflows/cargo-check.yml b/.github/workflows/cargo-check.yml
index 192ae0c24..1030b6cef 100644
--- a/.github/workflows/cargo-check.yml
+++ b/.github/workflows/cargo-check.yml
@@ -21,4 +21,9 @@ jobs:
         uses: actions-rs/cargo@v1
         with:
           command: check
-          args: --features test-cuda,ci-check
+          args: --features cuda,ci-check
+      - name: Check CUDNN
+        uses: actions-rs/cargo@v1
+        with:
+          command: check
+          args: --features cudnn,ci-check
diff --git a/Cargo.toml b/Cargo.toml
index de5cdc442..7b2e0efca 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,7 +32,7 @@ matrixmultiply = { version = "0.3.2", default-features = false, optional = true }
 zip = { version = "0.6.2", default-features = false, optional = true }
 cblas-sys = { version = "0.1.4", default-features = false, optional = true }
 libc = { version = "0.2", default-features = false, optional = true }
-cudarc = { version = "0.9.5", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
+cudarc = { version = "0.9.6", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
 num-traits = { version = "0.2.15", default-features = false }
 safetensors = { version = "0.3", default-features = false, optional = true }
 memmap2 = { version = "0.5", default-features = false, optional = true }
@@ -56,12 +56,13 @@ no-std = ["no-std-compat", "dep:spin", "cudarc?/no-std"]
 cpu-seq-matmul = ["dep:matrixmultiply"]
 cpu-par-matmul = ["std", "dep:matrixmultiply", "matrixmultiply?/threading"]
 cpu-mkl-matmul = ["dep:cblas-sys", "dep:libc"]
+
 cuda = ["dep:cudarc", "dep:glob"]
+cudnn = ["cuda", "cudarc?/cudnn"]
 numpy = ["dep:zip", "std"]
 safetensors = ["dep:safetensors", "std", "dep:memmap2"]
-test-cuda = ["cuda"]
 test-f64 = []
 test-integrations = []
 ci-check = ["cudarc?/ci-check"]
diff --git a/src/lib.rs b/src/lib.rs
index 238778120..9974b2737 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -242,10 +242,10 @@ pub fn keep_denormals() {
 #[cfg(test)]
 pub(crate) mod tests {
 
-    #[cfg(not(feature = "test-cuda"))]
+    #[cfg(not(feature = "cuda"))]
     pub type TestDevice = crate::tensor::Cpu;
 
-    #[cfg(feature = "test-cuda")]
+    #[cfg(feature = "cuda")]
     pub type TestDevice = crate::tensor::Cuda;
 
     #[cfg(not(feature = "test-f64"))]
diff --git a/src/tensor/cuda/device.rs b/src/tensor/cuda/device.rs
index 230bfbccc..3b46c8839 100644
--- a/src/tensor/cuda/device.rs
+++ b/src/tensor/cuda/device.rs
@@ -20,6 +20,8 @@ pub struct Cuda {
     pub(crate) cpu: Cpu,
     pub(crate) dev: Arc<CudaDevice>,
     pub(crate) blas: Arc<CudaBlas>,
+    #[cfg(feature = "cudnn")]
+    pub(crate) cudnn: Arc<cudarc::cudnn::Cudnn>,
     /// A second stream for kernels to optionally execute on.
     pub(crate) par_stream: Arc<CudaStream>,
     pub(crate) workspace: Arc<Mutex<CudaSlice<u8>>>,
@@ -28,6 +30,8 @@
 #[derive(Debug)]
 pub enum CudaError {
     Blas(CublasError),
+    #[cfg(feature = "cudnn")]
+    Cudnn(cudarc::cudnn::CudnnError),
     Driver(DriverError),
     Cpu(CpuError),
 }
@@ -50,6 +54,13 @@ impl From<CublasError> for CudaError {
     }
 }
 
+#[cfg(feature = "cudnn")]
+impl From<cudarc::cudnn::CudnnError> for CudaError {
+    fn from(value: cudarc::cudnn::CudnnError) -> Self {
+        Self::Cudnn(value)
+    }
+}
+
 impl Default for Cuda {
     fn default() -> Self {
         Self::seed_from_u64(0)
     }
 }
@@ -72,12 +83,16 @@ impl Cuda {
         let cpu = Cpu::seed_from_u64(seed);
         let dev = CudaDevice::new(ordinal)?;
         let blas = Arc::new(CudaBlas::new(dev.clone())?);
+        #[cfg(feature = "cudnn")]
+        let cudnn = cudarc::cudnn::Cudnn::new(dev.clone())?;
         let par_stream = Arc::new(dev.fork_default_stream()?);
         let workspace = Arc::new(Mutex::new(dev.alloc_zeros::<u8>(0)?));
         Ok(Self {
             cpu,
             dev,
             blas,
+            #[cfg(feature = "cudnn")]
+            cudnn,
             par_stream,
             workspace,
         })
diff --git a/src/tensor_ops/conv2d/cudnn_kernel.rs b/src/tensor_ops/conv2d/cudnn_kernel.rs
new file mode 100644
index 000000000..48dfeec6b
--- /dev/null
+++ b/src/tensor_ops/conv2d/cudnn_kernel.rs
@@ -0,0 +1,177 @@
+use cudarc::cudnn::{self, Conv2dBackwardData, Conv2dBackwardFilter, Conv2dForward, CudnnDataType};
+use cudarc::driver::DeviceSlice;
+
+use crate::{
+    shapes::*,
+    tensor::{unique_id, Cuda, GhostTensor, Tensor},
+};
+
+use std::sync::Arc;
+
+trait HasCudnnKernel<E> {}
+impl HasCudnnKernel<f32> for Cuda {}
+impl HasCudnnKernel<f64> for Cuda {}
+
+fn make_4d<S: Shape>(strides: S::Concrete, pad: usize) -> [usize; 4] {
+    match S::NUM_DIMS {
+        3 => [pad, strides[0], strides[1], strides[2]],
+        4 => [strides[0], strides[1], strides[2], strides[3]],
+        _ => unreachable!("Only implemented for 3d & 4d arrays"),
+    }
+}
+
+impl<E: Dtype + CudnnDataType> super::Conv2DKernel<E> for Cuda
+where
+    Self: HasCudnnKernel<E>,
+{
+    fn alloc<S: Shape>(&self, shape: S) -> Result<Tensor<S, E, Self>, Self::Err> {
+        let data = Arc::new(unsafe { self.dev.alloc::<E>(shape.num_elements()) }?);
+        Ok(Tensor {
+            id: unique_id(),
+            data,
+            shape,
+            strides: shape.strides(),
+            device: self.clone(),
+            tape: Default::default(),
+        })
+    }
+    fn forward<L: Shape, R: Shape, O: Shape>(
+        &self,
+        op: super::Conv2DOp,
+        lhs: &Tensor<L, E, Self>,
+        rhs: &Tensor<R, E, Self>,
+        out: &mut Tensor<O, E, Self>,
+    ) -> Result<(), Self::Err> {
+        let conv = self.cudnn.create_conv2d::<E>(
+            [op.padding as i32, op.padding as i32],
+            [op.stride as i32, op.stride as i32],
+            [1, 1],
+            cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
+        )?;
+        let img = self.cudnn.create_4d_tensor_ex::<E>(
+            make_4d::<L>(lhs.shape.concrete(), 1).map(|x| x as i32),
+            make_4d::<L>(lhs.strides, 0).map(|x| x as i32),
+        )?;
+        let filter = self.cudnn.create_4d_filter::<E>(
+            cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
+            make_4d::<R>(rhs.shape.concrete(), 1).map(|x| x as i32),
+        )?;
+        let y = self.cudnn.create_4d_tensor_ex::<E>(
+            make_4d::<O>(out.shape.concrete(), 1).map(|x| x as i32),
+            make_4d::<O>(out.strides, 0).map(|x| x as i32),
+        )?;
+        let op = Conv2dForward {
+            conv: &conv,
+            x: &img,
+            w: &filter,
+            y: &y,
+        };
+
+        let algo = op.pick_algorithm()?;
+        let workspace_size_in_bytes = op.get_workspace_size(algo)?;
+
+        unsafe {
+            let mut workspace = self.get_workspace::<u8>(workspace_size_in_bytes)?;
+            let mut workspace = workspace
+                .transmute_mut::<u8>(workspace_size_in_bytes)
+                .unwrap();
+            assert_eq!(workspace.len(), workspace_size_in_bytes);
+            op.launch(
+                algo,
+                Some(&mut workspace),
+                (E::ONE, Default::default()),
+                lhs.data.as_ref(),
+                rhs.data.as_ref(),
+                Arc::get_mut(&mut out.data).unwrap(),
+            )?;
+        }
+
+        Ok(())
+    }
+
+    fn backward<L: Shape, R: Shape, O: Shape>(
+        &self,
+        op: super::Conv2DOp,
+        lhs: &Tensor<L, E, Self>,
+        grad_lhs: &mut Self::Vec<E>,
+        rhs: &Tensor<R, E, Self>,
+        grad_rhs: &mut Self::Vec<E>,
+        out: &GhostTensor<O, E, Self>,
+        grad_out: &Self::Vec<E>,
+    ) -> Result<(), Self::Err> {
+        let conv = self.cudnn.create_conv2d::<E>(
+            [op.padding as i32, op.padding as i32],
+            [op.stride as i32, op.stride as i32],
+            [1, 1],
+            cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
+        )?;
+        let img = self.cudnn.create_4d_tensor_ex::<E>(
+            make_4d::<L>(lhs.shape.concrete(), 1).map(|x| x as i32),
+            make_4d::<L>(lhs.strides, 0).map(|x| x as i32),
+        )?;
+        let filter = self.cudnn.create_4d_filter::<E>(
+            cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
+            make_4d::<R>(rhs.shape.concrete(), 1).map(|x| x as i32),
+        )?;
+        let out = self.cudnn.create_4d_tensor_ex::<E>(
+            make_4d::<O>(out.shape.concrete(), 1).map(|x| x as i32),
+            make_4d::<O>(out.strides, 0).map(|x| x as i32),
+        )?;
+
+        {
+            let op = Conv2dBackwardData {
+                conv: &conv,
+                dx: &img,
+                w: &filter,
+                dy: &out,
+            };
+            let algo = op.pick_algorithm()?;
+            let workspace_size_in_bytes = op.get_workspace_size(algo)?;
+
+            unsafe {
+                let mut workspace = self.get_workspace::<u8>(workspace_size_in_bytes)?;
+                let mut workspace = workspace
+                    .transmute_mut::<u8>(workspace_size_in_bytes)
+                    .unwrap();
+                assert_eq!(workspace.len(), workspace_size_in_bytes);
+                op.launch(
+                    algo,
+                    Some(&mut workspace),
+                    (E::ONE, Default::default()),
+                    grad_lhs,
+                    rhs.data.as_ref(),
+                    grad_out,
+                )
+            }?;
+        }
+
+        {
+            let op = Conv2dBackwardFilter {
+                conv: &conv,
+                x: &img,
+                dw: &filter,
+                dy: &out,
+            };
+
+            let algo = op.pick_algorithm()?;
+            let workspace_size_in_bytes = op.get_workspace_size(algo)?;
+
+            unsafe {
+                let mut workspace = self.get_workspace::<u8>(workspace_size_in_bytes)?;
+                let mut workspace = workspace
+                    .transmute_mut::<u8>(workspace_size_in_bytes)
+                    .unwrap();
+                assert_eq!(workspace.len(), workspace_size_in_bytes);
+                op.launch(
+                    algo,
+                    Some(&mut workspace),
+                    (E::ONE, Default::default()),
+                    lhs.data.as_ref(),
+                    grad_rhs,
+                    grad_out,
+                )
+            }?;
+        }
+        Ok(())
+    }
+}
diff --git a/src/tensor_ops/conv2d/mod.rs b/src/tensor_ops/conv2d/mod.rs
index f1f67db52..e89cb0735 100644
--- a/src/tensor_ops/conv2d/mod.rs
+++ b/src/tensor_ops/conv2d/mod.rs
@@ -1,8 +1,11 @@
 mod cpu_kernel;
 
-#[cfg(feature = "cuda")]
+#[cfg(all(not(feature = "cudnn"), feature = "cuda"))]
 mod cuda_kernel;
 
+#[cfg(feature = "cudnn")]
+mod cudnn_kernel;
+
 use crate::{shapes::*, tensor::*};
 
 #[repr(C)]
@@ -226,6 +229,7 @@ mod tests {
     use super::*;
     use crate::{tensor_ops::*, tests::*};
+    use num_traits::FromPrimitive;
 
     #[test]
     /// Produced by
@@ -434,6 +438,7 @@
         let x = x
             .broadcast::<Rank4<10, 3, 28, 28>, _>()
             .reshape::<Rank4<10, 3, 28, 28>>();
+        assert_eq!(x.strides, x.shape.strides());
 
         let y: Tensor<Rank4<10, 5, 9, 9>, _, _, _> = x.leaky_trace().conv2d::<3, 2>(w.clone());
         for i in 0..10 {
@@ -442,7 +447,7 @@
 
         let grads = y.square().mean().backward();
 
-        assert_close(&w0, &(grads.get(&w)).array());
+        w0.assert_close(&(grads.get(&w)).array(), TestDtype::from_f32(1e-3).unwrap());
 
         let x_grad = grads.get(&x) * 10.0;
         for i in 0..10 {
diff --git a/src/tensor_ops/select_and_gather/mod.rs b/src/tensor_ops/select_and_gather/mod.rs
index a8d6a6c3c..caf32fd92 100644
--- a/src/tensor_ops/select_and_gather/mod.rs
+++ b/src/tensor_ops/select_and_gather/mod.rs
@@ -214,7 +214,7 @@ mod tests {
         let _ = t.leaky_trace().select(dev.zeros_like(&(7, 4)));
     }
 
-    #[cfg(not(feature = "test-cuda"))]
+    #[cfg(not(feature = "cuda"))]
     #[test]
     #[should_panic = "Index out of bounds: index=[7]"]
     fn test_select_index_out_of_bounds() {
@@ -241,7 +241,7 @@
         let _ = t.leaky_trace().gather(dev.zeros_like(&(5, 4, 2)));
     }
 
-    #[cfg(not(feature = "test-cuda"))]
+    #[cfg(not(feature = "cuda"))]
     #[test]
     #[should_panic = "Index out of bounds: index=[7]"]
     fn test_gather_index_out_of_bounds() {
@@ -250,7 +250,7 @@
         let _ = t.leaky_trace().gather(dev.tensor([7, 6, 1, 2]));
     }
 
-    #[cfg(not(feature = "test-cuda"))]
+    #[cfg(not(feature = "cuda"))]
     #[test]
     #[should_panic = "Index out of bounds: index=[5, 0]"]
     fn test_gather_batch_out_of_bounds() {
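
Note on the gating in src/tensor_ops/conv2d/mod.rs above: because `cudnn` implies `cuda` (Cargo.toml: cudnn = ["cuda", "cudarc?/cudnn"]), the hand-written CUDA kernel must be compiled out whenever the cuDNN kernel is compiled in, otherwise Cuda would have two conflicting Conv2DKernel impls. The following minimal standalone sketch shows the mutually exclusive cfg pattern; the `which_backend` function and its feature declarations are illustrative only, not dfdx API:

// Sketch only; assumes a Cargo.toml declaring:
//   [features]
//   cuda = []
//   cudnn = ["cuda"]

// Compiled only when `cuda` is on and `cudnn` is off,
// mirroring `mod cuda_kernel` in the diff.
#[cfg(all(not(feature = "cudnn"), feature = "cuda"))]
fn which_backend() -> &'static str {
    "hand-written CUDA kernel"
}

// Compiled whenever `cudnn` is on (which also enables `cuda`),
// mirroring `mod cudnn_kernel` in the diff.
#[cfg(feature = "cudnn")]
fn which_backend() -> &'static str {
    "cuDNN kernel"
}

// Fallback so the sketch also builds with no GPU features enabled.
#[cfg(not(feature = "cuda"))]
fn which_backend() -> &'static str {
    "CPU kernel"
}

fn main() {
    // Exactly one of the three cfg'd functions exists in any build, e.g.
    // `cargo run --features cudnn` prints "conv2d backend: cuDNN kernel".
    println!("conv2d backend: {}", which_backend());
}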